1
0
Fork 0
mirror of https://github.com/documize/community.git synced 2025-07-18 20:59:43 +02:00

still moving codebase to new API (WIP)

This commit is contained in:
Harvey Kandola 2017-07-26 20:03:23 +01:00
parent 72b14def6d
commit d90b3249c3
44 changed files with 5276 additions and 336 deletions

View file

@ -16,11 +16,12 @@ import (
"strings" "strings"
"time" "time"
"golang.org/x/net/html"
"github.com/documize/community/core/api/endpoint/models" "github.com/documize/community/core/api/endpoint/models"
"github.com/documize/community/core/api/entity" "github.com/documize/community/core/api/entity"
"github.com/documize/community/core/log" "github.com/documize/community/core/log"
"github.com/documize/community/core/streamutil" "github.com/documize/community/core/streamutil"
"github.com/documize/community/domain/link"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
) )
@ -249,7 +250,7 @@ func (p *Persister) UpdatePage(page entity.Page, refID, userID string, skipRevis
} }
// find any content links in the HTML // find any content links in the HTML
links := link.GetContentLinks(page.Body) links := GetContentLinks(page.Body)
// get a copy of previously saved links // get a copy of previously saved links
previousLinks, _ := p.GetPageLinks(page.DocumentID, page.RefID) previousLinks, _ := p.GetPageLinks(page.DocumentID, page.RefID)
@ -497,3 +498,58 @@ func (p *Persister) GetNextPageSequence(documentID string) (maxSeq float64, err
return return
} }
// GetContentLinks returns Documize generated <a> links.
// such links have an identifying attribute e.g. <a data-documize='true'...
func GetContentLinks(body string) (links []entity.Link) {
z := html.NewTokenizer(strings.NewReader(body))
for {
tt := z.Next()
switch {
case tt == html.ErrorToken:
// End of the document, we're done
return
case tt == html.StartTagToken:
t := z.Token()
// Check if the token is an <a> tag
isAnchor := t.Data == "a"
if !isAnchor {
continue
}
// Extract the content link
ok, link := getLink(t)
if ok {
links = append(links, link)
}
}
}
}
// Helper function to pull the href attribute from a Token
func getLink(t html.Token) (ok bool, link entity.Link) {
ok = false
// Iterate over all of the Token's attributes until we find an "href"
for _, a := range t.Attr {
switch a.Key {
case "data-documize":
ok = true
case "data-link-id":
link.RefID = strings.TrimSpace(a.Val)
case "data-link-space-id":
link.FolderID = strings.TrimSpace(a.Val)
case "data-link-target-document-id":
link.TargetDocumentID = strings.TrimSpace(a.Val)
case "data-link-target-id":
link.TargetID = strings.TrimSpace(a.Val)
case "data-link-type":
link.LinkType = strings.TrimSpace(a.Val)
}
}
return
}

View file

@ -0,0 +1,224 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package attachment
import (
"bytes"
"database/sql"
"fmt"
"io"
"mime"
"net/http"
"github.com/documize/community/core/env"
"github.com/documize/community/core/request"
"github.com/documize/community/core/response"
"github.com/documize/community/core/secrets"
"github.com/documize/community/core/uniqueid"
"github.com/documize/community/domain"
"github.com/documize/community/domain/document"
"github.com/documize/community/model/attachment"
"github.com/documize/community/model/audit"
uuid "github.com/nu7hatch/gouuid"
)
// Handler contains the runtime information such as logging and database.
type Handler struct {
Runtime *env.Runtime
Store *domain.Store
}
// Download is the end-point that responds to a request for a particular attachment
// by sending the requested file to the client.
func (h *Handler) Download(w http.ResponseWriter, r *http.Request) {
method := "attachment.Download"
ctx := domain.GetRequestContext(r)
a, err := h.Store.Attachment.GetAttachment(ctx, request.Param(r, "orgID"), request.Param(r, "attachmentID"))
if err == sql.ErrNoRows {
response.WriteNotFoundError(w, method, request.Param(r, "fileID"))
return
}
if err != nil {
response.WriteServerError(w, method, err)
return
}
typ := mime.TypeByExtension("." + a.Extension)
if typ == "" {
typ = "application/octet-stream"
}
w.Header().Set("Content-Type", typ)
w.Header().Set("Content-Disposition", `Attachment; filename="`+a.Filename+`" ; `+`filename*="`+a.Filename+`"`)
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(a.Data)))
w.WriteHeader(http.StatusOK)
_, err = w.Write(a.Data)
if err != nil {
h.Runtime.Log.Error("writing attachment", err)
return
}
h.Store.Audit.Record(ctx, audit.EventTypeAttachmentDownload)
}
// Get is an end-point that returns all of the attachments of a particular documentID.
func (h *Handler) Get(w http.ResponseWriter, r *http.Request) {
method := "attachment.GetAttachments"
ctx := domain.GetRequestContext(r)
documentID := request.Param(r, "documentID")
if len(documentID) == 0 {
response.WriteMissingDataError(w, method, "documentID")
return
}
if !document.CanViewDocument(ctx, *h.Store, documentID) {
response.WriteForbiddenError(w)
return
}
a, err := h.Store.Attachment.GetAttachments(ctx, documentID)
if err != nil && err != sql.ErrNoRows {
response.WriteServerError(w, method, err)
return
}
if len(a) == 0 {
a = []attachment.Attachment{}
}
response.WriteJSON(w, a)
}
// Delete is an endpoint that deletes a particular document attachment.
func (h *Handler) Delete(w http.ResponseWriter, r *http.Request) {
method := "attachment.DeleteAttachment"
ctx := domain.GetRequestContext(r)
documentID := request.Param(r, "documentID")
if len(documentID) == 0 {
response.WriteMissingDataError(w, method, "documentID")
return
}
attachmentID := request.Param(r, "attachmentID")
if len(attachmentID) == 0 {
response.WriteMissingDataError(w, method, "attachmentID")
return
}
if !document.CanChangeDocument(ctx, *h.Store, documentID) {
response.WriteForbiddenError(w)
return
}
var err error
ctx.Transaction, err = h.Runtime.Db.Beginx()
if err != nil {
response.WriteServerError(w, method, err)
return
}
_, err = h.Store.Attachment.Delete(ctx, attachmentID)
if err != nil {
ctx.Transaction.Rollback()
response.WriteServerError(w, method, err)
return
}
// Mark references to this document as orphaned
err = h.Store.Link.MarkOrphanAttachmentLink(ctx, attachmentID)
if err != nil {
ctx.Transaction.Rollback()
response.WriteServerError(w, method, err)
return
}
h.Store.Audit.Record(ctx, audit.EventTypeAttachmentDelete)
ctx.Transaction.Commit()
response.WriteEmpty(w)
}
// Add stores files against a document.
func (h *Handler) Add(w http.ResponseWriter, r *http.Request) {
method := "attachment.Add"
ctx := domain.GetRequestContext(r)
documentID := request.Param(r, "documentID")
if len(documentID) == 0 {
response.WriteMissingDataError(w, method, "documentID")
return
}
if !document.CanChangeDocument(ctx, *h.Store, documentID) {
response.WriteForbiddenError(w)
return
}
filedata, filename, err := r.FormFile("attachment")
if err != nil {
response.WriteMissingDataError(w, method, "attachment")
return
}
b := new(bytes.Buffer)
_, err = io.Copy(b, filedata)
if err != nil {
response.WriteServerError(w, method, err)
return
}
var job = "some-uuid"
newUUID, err := uuid.NewV4()
if err != nil {
response.WriteServerError(w, method, err)
return
}
job = newUUID.String()
var a attachment.Attachment
refID := uniqueid.Generate()
a.RefID = refID
a.DocumentID = documentID
a.Job = job
random := secrets.GenerateSalt()
a.FileID = random[0:9]
a.Filename = filename.Filename
a.Data = b.Bytes()
ctx.Transaction, err = h.Runtime.Db.Beginx()
if err != nil {
response.WriteServerError(w, method, err)
return
}
err = h.Store.Attachment.Add(ctx, a)
if err != nil {
ctx.Transaction.Rollback()
response.WriteServerError(w, method, err)
return
}
h.Store.Audit.Record(ctx, audit.EventTypeAttachmentAdd)
ctx.Transaction.Commit()
response.WriteEmpty(w)
}

View file

@ -0,0 +1,104 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package attachment
import (
"strings"
"time"
"github.com/documize/community/domain"
"github.com/documize/community/domain/store/mysql"
"github.com/pkg/errors"
"github.com/documize/community/core/env"
"github.com/documize/community/core/streamutil"
"github.com/documize/community/model/attachment"
)
// Scope provides data access to MySQL.
type Scope struct {
Runtime *env.Runtime
}
// Add inserts the given record into the database attachement table.
func (s Scope) Add(ctx domain.RequestContext, a attachment.Attachment) (err error) {
a.OrgID = ctx.OrgID
a.Created = time.Now().UTC()
a.Revised = time.Now().UTC()
bits := strings.Split(a.Filename, ".")
a.Extension = bits[len(bits)-1]
stmt, err := ctx.Transaction.Preparex("INSERT INTO attachment (refid, orgid, documentid, job, fileid, filename, data, extension, created, revised) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
defer streamutil.Close(stmt)
if err != nil {
err = errors.Wrap(err, "prepare insert attachment")
return
}
_, err = stmt.Exec(a.RefID, a.OrgID, a.DocumentID, a.Job, a.FileID, a.Filename, a.Data, a.Extension, a.Created, a.Revised)
if err != nil {
err = errors.Wrap(err, "execute insert attachment")
return
}
return
}
// GetAttachment returns the database attachment record specified by the parameters.
func (s Scope) GetAttachment(ctx domain.RequestContext, orgID, attachmentID string) (a attachment.Attachment, err error) {
stmt, err := s.Runtime.Db.Preparex("SELECT id, refid, orgid, documentid, job, fileid, filename, data, extension, created, revised FROM attachment WHERE orgid=? and refid=?")
defer streamutil.Close(stmt)
if err != nil {
err = errors.Wrap(err, "prepare select attachment")
return
}
err = stmt.Get(&a, orgID, attachmentID)
if err != nil {
err = errors.Wrap(err, "execute select attachment")
return
}
return
}
// GetAttachments returns a slice containing the attachement records (excluding their data) for document docID, ordered by filename.
func (s Scope) GetAttachments(ctx domain.RequestContext, docID string) (a []attachment.Attachment, err error) {
err = s.Runtime.Db.Select(&a, "SELECT id, refid, orgid, documentid, job, fileid, filename, extension, created, revised FROM attachment WHERE orgid=? and documentid=? order by filename", ctx.OrgID, docID)
if err != nil {
err = errors.Wrap(err, "execute select attachments")
return
}
return
}
// GetAttachmentsWithData returns a slice containing the attachement records (including their data) for document docID, ordered by filename.
func (s Scope) GetAttachmentsWithData(ctx domain.RequestContext, docID string) (a []attachment.Attachment, err error) {
err = s.Runtime.Db.Select(&a, "SELECT id, refid, orgid, documentid, job, fileid, filename, extension, data, created, revised FROM attachment WHERE orgid=? and documentid=? order by filename", ctx.OrgID, docID)
if err != nil {
err = errors.Wrap(err, "execute select attachments with data")
return
}
return
}
// Delete deletes the id record from the database attachment table.
func (s Scope) Delete(ctx domain.RequestContext, id string) (rows int64, err error) {
b := mysql.BaseQuery{}
return b.DeleteConstrained(ctx.Transaction, "attachment", ctx.OrgID, id)
}

View file

@ -11,13 +11,34 @@
package auth package auth
/* import (
// Authenticate user based up HTTP Authorization header. "database/sql"
// An encrypted authentication token is issued with an expiry date. "errors"
func (h *Handler) Authenticate(w http.ResponseWriter, r *http.Request) { "net/http"
method := "Authenticate" "strings"
s := domain.StoreContext{Runtime: h.Runtime, Context: domain.GetRequestContext(r)} "github.com/documize/community/core/env"
"github.com/documize/community/core/response"
"github.com/documize/community/core/secrets"
"github.com/documize/community/domain"
"github.com/documize/community/domain/organization"
"github.com/documize/community/domain/section/provider"
"github.com/documize/community/domain/user"
"github.com/documize/community/model/auth"
"github.com/documize/community/model/org"
)
// Handler contains the runtime information such as logging and database.
type Handler struct {
Runtime *env.Runtime
Store *domain.Store
}
// Login user based up HTTP Authorization header.
// An encrypted authentication token is issued with an expiry date.
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
method := "auth.Login"
ctx := domain.GetRequestContext(r)
// check for http header // check for http header
authHeader := r.Header.Get("Authorization") authHeader := r.Header.Get("Authorization")
@ -46,23 +67,20 @@ func (h *Handler) Authenticate(w http.ResponseWriter, r *http.Request) {
} }
dom := strings.TrimSpace(strings.ToLower(credentials[0])) dom := strings.TrimSpace(strings.ToLower(credentials[0]))
dom = organization.CheckDomain(s, dom) // TODO optimize by removing this once js allows empty domains dom = h.Store.Organization.CheckDomain(ctx, dom) // TODO optimize by removing this once js allows empty domains
email := strings.TrimSpace(strings.ToLower(credentials[1])) email := strings.TrimSpace(strings.ToLower(credentials[1]))
password := credentials[2] password := credentials[2]
h.Runtime.Log.Info("logon attempt " + email + " @ " + dom) h.Runtime.Log.Info("logon attempt " + email + " @ " + dom)
u, err := user.GetByDomain(s, dom, email) u, err := h.Store.User.GetByDomain(ctx, dom, email)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
response.WriteUnauthorizedError(w) response.WriteUnauthorizedError(w)
return return
} }
if err != nil { if err != nil {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
if len(u.Reset) > 0 || len(u.Password) == 0 { if len(u.Reset) > 0 || len(u.Password) == 0 {
response.WriteUnauthorizedError(w) response.WriteUnauthorizedError(w)
return return
@ -74,31 +92,29 @@ func (h *Handler) Authenticate(w http.ResponseWriter, r *http.Request) {
return return
} }
org, err := organization.GetOrganizationByDomain(s, dom) org, err := h.Store.Organization.GetOrganizationByDomain(ctx, dom)
if err != nil { if err != nil {
response.WriteUnauthorizedError(w) response.WriteUnauthorizedError(w)
return return
} }
// Attach user accounts and work out permissions // Attach user accounts and work out permissions
user.AttachUserAccounts(s, org.RefID, &u) user.AttachUserAccounts(ctx, *h.Store, org.RefID, &u)
// active check
if len(u.Accounts) == 0 { if len(u.Accounts) == 0 {
response.WriteUnauthorizedError(w) response.WriteUnauthorizedError(w)
return return
} }
authModel := AuthenticationModel{} authModel := auth.AuthenticationModel{}
authModel.Token = GenerateJWT(h.Runtime, u.RefID, org.RefID, dom) authModel.Token = GenerateJWT(h.Runtime, u.RefID, org.RefID, dom)
authModel.User = u authModel.User = u
response.WriteJSON(w, authModel) response.WriteJSON(w, authModel)
} }
// ValidateAuthToken finds and validates authentication token. // ValidateToken finds and validates authentication token.
func (h *Handler) ValidateAuthToken(w http.ResponseWriter, r *http.Request) { func (h *Handler) ValidateToken(w http.ResponseWriter, r *http.Request) {
// TODO should this go after token validation? // TODO should this go after token validation?
if s := r.URL.Query().Get("section"); s != "" { if s := r.URL.Query().Get("section"); s != "" {
if err := provider.Callback(s, w, r); err != nil { if err := provider.Callback(s, w, r); err != nil {
@ -109,40 +125,40 @@ func (h *Handler) ValidateAuthToken(w http.ResponseWriter, r *http.Request) {
return return
} }
s := domain.StoreContext{Runtime: h.Runtime, Context: domain.GetRequestContext(r)}
token := FindJWT(r) token := FindJWT(r)
rc, _, tokenErr := DecodeJWT(h.Runtime, token) rc, _, tokenErr := DecodeJWT(h.Runtime, token)
var org = organization.Organization{} var org = org.Organization{}
var err = errors.New("") var err = errors.New("")
// We always grab the org record regardless of token status. // We always grab the org record regardless of token status.
// Why? If bad token we might be OK to alow anonymous access // Why? If bad token we might be OK to alow anonymous access
// depending upon the domain in question. // depending upon the domain in question.
if len(rc.OrgID) == 0 { if len(rc.OrgID) == 0 {
org, err = organization.GetOrganizationByDomain(s, organization.GetRequestSubdomain(s, r)) dom := organization.GetRequestSubdomain(r)
org, err = h.Store.Organization.GetOrganizationByDomain(rc, dom)
} else { } else {
org, err = organization.GetOrganization(s, rc.OrgID) org, err = h.Store.Organization.GetOrganization(rc, rc.OrgID)
} }
rc.Subdomain = org.Domain rc.Subdomain = org.Domain
// Inability to find org record spells the end of this request. // Inability to find org record spells the end of this request.
if err != nil { if err != nil {
w.WriteHeader(http.StatusUnauthorized) response.WriteUnauthorizedError(w)
return return
} }
// If we have bad auth token and the domain does not allow anon access // If we have bad auth token and the domain does not allow anon access
if !org.AllowAnonymousAccess && tokenErr != nil { if !org.AllowAnonymousAccess && tokenErr != nil {
response.WriteUnauthorizedError(w)
return return
} }
dom := organization.GetSubdomainFromHost(s, r) dom := organization.GetSubdomainFromHost(r)
dom2 := organization.GetRequestSubdomain(s, r) dom2 := organization.GetRequestSubdomain(r)
if org.Domain != dom && org.Domain != dom2 { if org.Domain != dom && org.Domain != dom2 {
w.WriteHeader(http.StatusUnauthorized) response.WriteUnauthorizedError(w)
return return
} }
@ -152,7 +168,7 @@ func (h *Handler) ValidateAuthToken(w http.ResponseWriter, r *http.Request) {
// So you have a bad token // So you have a bad token
if len(token) > 1 { if len(token) > 1 {
if tokenErr != nil { if tokenErr != nil {
w.WriteHeader(http.StatusUnauthorized) response.WriteUnauthorizedError(w)
return return
} }
} else { } else {
@ -170,18 +186,18 @@ func (h *Handler) ValidateAuthToken(w http.ResponseWriter, r *http.Request) {
rc.Editor = false rc.Editor = false
rc.Global = false rc.Global = false
rc.AppURL = r.Host rc.AppURL = r.Host
rc.Subdomain = organization.GetSubdomainFromHost(s, r) rc.Subdomain = organization.GetSubdomainFromHost(r)
rc.SSL = r.TLS != nil rc.SSL = r.TLS != nil
// Fetch user permissions for this org // Fetch user permissions for this org
if !rc.Authenticated { if !rc.Authenticated {
w.WriteHeader(http.StatusUnauthorized) response.WriteUnauthorizedError(w)
return return
} }
u, err := user.GetSecuredUser(s, org.RefID, rc.UserID) u, err := user.GetSecuredUser(rc, *h.Store, org.RefID, rc.UserID)
if err != nil { if err != nil {
w.WriteHeader(http.StatusUnauthorized) response.WriteUnauthorizedError(w)
return return
} }
@ -190,6 +206,4 @@ func (h *Handler) ValidateAuthToken(w http.ResponseWriter, r *http.Request) {
rc.Global = u.Global rc.Global = u.Global
response.WriteJSON(w, u) response.WriteJSON(w, u)
return
} }
*/

49
domain/auth/keycloak.go Normal file
View file

@ -0,0 +1,49 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package auth
import (
"encoding/json"
"github.com/documize/community/core/env"
"github.com/documize/community/model/auth"
)
// StripAuthSecrets removes sensitive data from auth provider configuration
func StripAuthSecrets(r *env.Runtime, provider, config string) string {
switch provider {
case "documize":
return config
break
case "keycloak":
c := auth.KeycloakConfig{}
err := json.Unmarshal([]byte(config), &c)
if err != nil {
r.Log.Error("StripAuthSecrets", err)
return config
}
c.AdminPassword = ""
c.AdminUser = ""
c.PublicKey = ""
j, err := json.Marshal(c)
if err != nil {
r.Log.Error("StripAuthSecrets", err)
return config
}
return string(j)
break
}
return config
}

View file

@ -17,6 +17,7 @@ import (
"github.com/documize/community/core/env" "github.com/documize/community/core/env"
"github.com/documize/community/core/streamutil" "github.com/documize/community/core/streamutil"
"github.com/documize/community/domain" "github.com/documize/community/domain"
"github.com/documize/community/model/doc"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -25,6 +26,25 @@ type Scope struct {
Runtime *env.Runtime Runtime *env.Runtime
} }
// Get fetches the document record with the given id fromt the document table and audits that it has been got.
func (s Scope) Get(ctx domain.RequestContext, id string) (document doc.Document, err error) {
stmt, err := s.Runtime.Db.Preparex("SELECT id, refid, orgid, labelid, userid, job, location, title, excerpt, slug, tags, template, layout, created, revised FROM document WHERE orgid=? and refid=?")
defer streamutil.Close(stmt)
if err != nil {
err = errors.Wrap(err, "prepare select document")
return
}
err = stmt.Get(&document, ctx.OrgID, id)
if err != nil {
err = errors.Wrap(err, "execute select document")
return
}
return
}
// MoveDocumentSpace changes the label for client's organization's documents which have space "id", to "move". // MoveDocumentSpace changes the label for client's organization's documents which have space "id", to "move".
func (s Scope) MoveDocumentSpace(ctx domain.RequestContext, id, move string) (err error) { func (s Scope) MoveDocumentSpace(ctx domain.RequestContext, id, move string) (err error) {
stmt, err := ctx.Transaction.Preparex("UPDATE document SET labelid=? WHERE orgid=? AND labelid=?") stmt, err := ctx.Transaction.Preparex("UPDATE document SET labelid=? WHERE orgid=? AND labelid=?")
@ -43,3 +63,20 @@ func (s Scope) MoveDocumentSpace(ctx domain.RequestContext, id, move string) (er
return return
} }
// PublicDocuments returns a slice of SitemapDocument records, holding documents in folders of type 1 (entity.TemplateTypePublic).
func (s Scope) PublicDocuments(ctx domain.RequestContext, orgID string) (documents []doc.SitemapDocument, err error) {
err = s.Runtime.Db.Select(&documents,
`SELECT d.refid as documentid, d.title as document, d.revised as revised, l.refid as folderid, l.label as folder
FROM document d LEFT JOIN label l ON l.refid=d.labelid
WHERE d.orgid=?
AND l.type=1
AND d.template=0`, orgID)
if err != nil {
err = errors.Wrap(err, fmt.Sprintf("execute GetPublicDocuments for org %s%s", orgID))
return
}
return
}

View file

@ -0,0 +1,116 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package document
import (
"database/sql"
"github.com/documize/community/domain"
)
// CanViewDocumentInFolder returns if the user has permission to view a document within the specified folder.
func CanViewDocumentInFolder(ctx domain.RequestContext, s domain.Store, labelID string) (hasPermission bool) {
roles, err := s.Space.GetUserRoles(ctx)
if err == sql.ErrNoRows {
err = nil
}
if err != nil {
return false
}
for _, role := range roles {
if role.LabelID == labelID && (role.CanView || role.CanEdit) {
return true
}
}
return false
}
// CanViewDocument returns if the clinet has permission to view a given document.
func CanViewDocument(ctx domain.RequestContext, s domain.Store, documentID string) (hasPermission bool) {
document, err := s.Document.Get(ctx, documentID)
if err == sql.ErrNoRows {
err = nil
}
if err != nil {
return false
}
roles, err := s.Space.GetUserRoles(ctx)
if err == sql.ErrNoRows {
err = nil
}
if err != nil {
return false
}
for _, role := range roles {
if role.LabelID == document.LabelID && (role.CanView || role.CanEdit) {
return true
}
}
return false
}
// CanChangeDocument returns if the clinet has permission to change a given document.
func CanChangeDocument(ctx domain.RequestContext, s domain.Store, documentID string) (hasPermission bool) {
document, err := s.Document.Get(ctx, documentID)
if err == sql.ErrNoRows {
err = nil
}
if err != nil {
return false
}
roles, err := s.Space.GetUserRoles(ctx)
if err == sql.ErrNoRows {
err = nil
}
if err != nil {
return false
}
for _, role := range roles {
if role.LabelID == document.LabelID && role.CanEdit {
return true
}
}
return false
}
// CanUploadDocument returns if the client has permission to upload documents to the given folderID.
func CanUploadDocument(ctx domain.RequestContext, s domain.Store, folderID string) (hasPermission bool) {
roles, err := s.Space.GetUserRoles(ctx)
if err == sql.ErrNoRows {
err = nil
}
if err != nil {
return false
}
for _, role := range roles {
if role.LabelID == folderID && role.CanEdit {
return true
}
}
return false
}

160
domain/link/endpoint.go Normal file
View file

@ -0,0 +1,160 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package link
import (
"database/sql"
"net/http"
"net/url"
"github.com/documize/community/core/env"
"github.com/documize/community/core/request"
"github.com/documize/community/core/response"
"github.com/documize/community/core/uniqueid"
"github.com/documize/community/domain"
"github.com/documize/community/domain/document"
"github.com/documize/community/model/attachment"
"github.com/documize/community/model/link"
"github.com/documize/community/model/page"
)
// Handler contains the runtime information such as logging and database.
type Handler struct {
Runtime *env.Runtime
Store *domain.Store
}
// GetLinkCandidates returns references to documents/sections/attachments.
func (h *Handler) GetLinkCandidates(w http.ResponseWriter, r *http.Request) {
method := "link.Candidates"
ctx := domain.GetRequestContext(r)
folderID := request.Param(r, "folderID")
if len(folderID) == 0 {
response.WriteMissingDataError(w, method, "folderID")
return
}
documentID := request.Param(r, "documentID")
if len(documentID) == 0 {
response.WriteMissingDataError(w, method, "documentID")
return
}
pageID := request.Param(r, "pageID")
if len(pageID) == 0 {
response.WriteMissingDataError(w, method, "pageID")
return
}
// permission check
if document.CanViewDocument(ctx, *h.Store, documentID) {
response.WriteForbiddenError(w)
return
}
// We can link to a section within the same document so
// let's get all pages for the document and remove "us".
pages, err := h.Store.Page.GetPagesWithoutContent(ctx, documentID)
if err != nil && err != sql.ErrNoRows {
response.WriteServerError(w, method, err)
return
}
if len(pages) == 0 {
pages = []page.Page{}
}
pc := []link.Candidate{}
for _, p := range pages {
if p.RefID != pageID {
c := link.Candidate{
RefID: uniqueid.Generate(),
FolderID: folderID,
DocumentID: documentID,
TargetID: p.RefID,
LinkType: p.PageType,
Title: p.Title,
}
pc = append(pc, c)
}
}
// We can link to attachment within the same document so
// let's get all attachments for the document.
files, err := h.Store.Attachment.GetAttachments(ctx, documentID)
if err != nil && err != sql.ErrNoRows {
response.WriteServerError(w, method, err)
return
}
if len(files) == 0 {
files = []attachment.Attachment{}
}
fc := []link.Candidate{}
for _, f := range files {
c := link.Candidate{
RefID: uniqueid.Generate(),
FolderID: folderID,
DocumentID: documentID,
TargetID: f.RefID,
LinkType: "file",
Title: f.Filename,
Context: f.Extension,
}
fc = append(fc, c)
}
var payload struct {
Pages []link.Candidate `json:"pages"`
Attachments []link.Candidate `json:"attachments"`
}
payload.Pages = pc
payload.Attachments = fc
response.WriteJSON(w, payload)
}
// SearchLinkCandidates endpoint takes a list of keywords and returns a list of document references matching those keywords.
func (h *Handler) SearchLinkCandidates(w http.ResponseWriter, r *http.Request) {
method := "link.SearchLinkCandidates"
ctx := domain.GetRequestContext(r)
keywords := request.Query(r, "keywords")
decoded, err := url.QueryUnescape(keywords)
if err != nil {
h.Runtime.Log.Error("decode query string", err)
}
docs, pages, attachments, err := h.Store.Link.SearchCandidates(ctx, decoded)
if err != nil {
response.WriteServerError(w, method, err)
return
}
var payload struct {
Documents []link.Candidate `json:"documents"`
Pages []link.Candidate `json:"pages"`
Attachments []link.Candidate `json:"attachments"`
}
payload.Documents = docs
payload.Pages = pages
payload.Attachments = attachments
response.WriteJSON(w, payload)
}

View file

@ -14,13 +14,13 @@ package link
import ( import (
"strings" "strings"
"github.com/documize/community/core/api/entity" "github.com/documize/community/model/link"
"golang.org/x/net/html" "golang.org/x/net/html"
) )
// GetContentLinks returns Documize generated <a> links. // GetContentLinks returns Documize generated <a> links.
// such links have an identifying attribute e.g. <a data-documize='true'... // such links have an identifying attribute e.g. <a data-documize='true'...
func GetContentLinks(body string) (links []entity.Link) { func GetContentLinks(body string) (links []link.Link) {
z := html.NewTokenizer(strings.NewReader(body)) z := html.NewTokenizer(strings.NewReader(body))
for { for {
@ -49,7 +49,7 @@ func GetContentLinks(body string) (links []entity.Link) {
} }
// Helper function to pull the href attribute from a Token // Helper function to pull the href attribute from a Token
func getLink(t html.Token) (ok bool, link entity.Link) { func getLink(t html.Token) (ok bool, link link.Link) {
ok = false ok = false
// Iterate over all of the Token's attributes until we find an "href" // Iterate over all of the Token's attributes until we find an "href"

292
domain/link/mysql/store.go Normal file
View file

@ -0,0 +1,292 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package link
import (
"fmt"
"time"
"github.com/documize/community/core/env"
"github.com/documize/community/core/streamutil"
"github.com/documize/community/core/uniqueid"
"github.com/documize/community/domain"
"github.com/documize/community/domain/store/mysql"
"github.com/documize/community/model/link"
"github.com/pkg/errors"
)
// Scope provides data access to MySQL.
type Scope struct {
Runtime *env.Runtime
}
// Add inserts wiki-link into the store.
// These links exist when content references another document or content.
func (s Scope) Add(ctx domain.RequestContext, l link.Link) (err error) {
l.Created = time.Now().UTC()
l.Revised = time.Now().UTC()
stmt, err := ctx.Transaction.Preparex("INSERT INTO link (refid, orgid, folderid, userid, sourcedocumentid, sourcepageid, targetdocumentid, targetid, linktype, orphan, created, revised) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
defer streamutil.Close(stmt)
if err != nil {
err = errors.Wrap(err, "prepare link insert")
return
}
_, err = stmt.Exec(l.RefID, l.OrgID, l.FolderID, l.UserID, l.SourceDocumentID, l.SourcePageID, l.TargetDocumentID, l.TargetID, l.LinkType, l.Orphan, l.Created, l.Revised)
if err != nil {
err = errors.Wrap(err, "execute link insert")
return
}
return
}
// GetDocumentOutboundLinks returns outbound links for specified document.
func (s Scope) GetDocumentOutboundLinks(ctx domain.RequestContext, documentID string) (links []link.Link, err error) {
err = s.Runtime.Db.Select(&links,
`select l.refid, l.orgid, l.folderid, l.userid, l.sourcedocumentid, l.sourcepageid, l.targetdocumentid, l.targetid, l.linktype, l.orphan, l.created, l.revised
FROM link l
WHERE l.orgid=? AND l.sourcedocumentid=?`,
ctx.OrgID,
documentID)
if err != nil {
return
}
if len(links) == 0 {
links = []link.Link{}
}
return
}
// GetPageLinks returns outbound links for specified page in document.
func (s Scope) GetPageLinks(ctx domain.RequestContext, documentID, pageID string) (links []link.Link, err error) {
err = s.Runtime.Db.Select(&links,
`select l.refid, l.orgid, l.folderid, l.userid, l.sourcedocumentid, l.sourcepageid, l.targetdocumentid, l.targetid, l.linktype, l.orphan, l.created, l.revised
FROM link l
WHERE l.orgid=? AND l.sourcedocumentid=? AND l.sourcepageid=?`,
ctx.OrgID,
documentID,
pageID)
if err != nil {
return
}
if len(links) == 0 {
links = []link.Link{}
}
return
}
// MarkOrphanDocumentLink marks all link records referencing specified document.
func (s Scope) MarkOrphanDocumentLink(ctx domain.RequestContext, documentID string) (err error) {
revised := time.Now().UTC()
stmt, err := ctx.Transaction.Preparex("UPDATE link SET orphan=1, revised=? WHERE linktype='document' AND orgid=? AND targetdocumentid=?")
defer streamutil.Close(stmt)
if err != nil {
return
}
_, err = stmt.Exec(revised, ctx.OrgID, documentID)
return
}
// MarkOrphanPageLink marks all link records referencing specified page.
func (s Scope) MarkOrphanPageLink(ctx domain.RequestContext, pageID string) (err error) {
revised := time.Now().UTC()
stmt, err := ctx.Transaction.Preparex("UPDATE link SET orphan=1, revised=? WHERE linktype='section' AND orgid=? AND targetid=?")
defer streamutil.Close(stmt)
if err != nil {
return
}
_, err = stmt.Exec(revised, ctx.OrgID, pageID)
return
}
// MarkOrphanAttachmentLink marks all link records referencing specified attachment.
func (s Scope) MarkOrphanAttachmentLink(ctx domain.RequestContext, attachmentID string) (err error) {
revised := time.Now().UTC()
stmt, err := ctx.Transaction.Preparex("UPDATE link SET orphan=1, revised=? WHERE linktype='file' AND orgid=? AND targetid=?")
defer streamutil.Close(stmt)
if err != nil {
return
}
_, err = stmt.Exec(revised, ctx.OrgID, attachmentID)
return
}
// DeleteSourcePageLinks removes saved links for given source.
func (s Scope) DeleteSourcePageLinks(ctx domain.RequestContext, pageID string) (rows int64, err error) {
b := mysql.BaseQuery{}
return b.DeleteWhere(ctx.Transaction, fmt.Sprintf("DELETE FROM link WHERE orgid=\"%s\" AND sourcepageid=\"%s\"", ctx.OrgID, pageID))
}
// DeleteSourceDocumentLinks removes saved links for given document.
func (s Scope) DeleteSourceDocumentLinks(ctx domain.RequestContext, documentID string) (rows int64, err error) {
b := mysql.BaseQuery{}
return b.DeleteWhere(ctx.Transaction, fmt.Sprintf("DELETE FROM link WHERE orgid=\"%s\" AND sourcedocumentid=\"%s\"", ctx.OrgID, documentID))
}
// DeleteLink removes saved link from the store.
func (s Scope) DeleteLink(ctx domain.RequestContext, id string) (rows int64, err error) {
b := mysql.BaseQuery{}
return b.DeleteConstrained(ctx.Transaction, "link", ctx.OrgID, id)
}
// SearchCandidates returns matching documents, sections and attachments using keywords.
func (s Scope) SearchCandidates(ctx domain.RequestContext, keywords string) (docs []link.Candidate,
pages []link.Candidate, attachments []link.Candidate, err error) {
// find matching documents
temp := []link.Candidate{}
likeQuery := "title LIKE '%" + keywords + "%'"
err = s.Runtime.Db.Select(&temp,
`SELECT refid as documentid, labelid as folderid,title from document WHERE orgid=? AND `+likeQuery+` AND labelid IN
(SELECT refid from label WHERE orgid=? AND type=2 AND userid=?
UNION ALL SELECT refid FROM label a where orgid=? AND type=1 AND refid IN (SELECT labelid from labelrole WHERE orgid=? AND userid='' AND (canedit=1 OR canview=1))
UNION ALL SELECT refid FROM label a where orgid=? AND type=3 AND refid IN (SELECT labelid from labelrole WHERE orgid=? AND userid=? AND (canedit=1 OR canview=1)))
ORDER BY title`,
ctx.OrgID,
ctx.OrgID,
ctx.UserID,
ctx.OrgID,
ctx.OrgID,
ctx.OrgID,
ctx.OrgID,
ctx.UserID)
if err != nil {
err = errors.Wrap(err, "execute search links 1")
return
}
for _, r := range temp {
c := link.Candidate{
RefID: uniqueid.Generate(),
FolderID: r.FolderID,
DocumentID: r.DocumentID,
TargetID: r.DocumentID,
LinkType: "document",
Title: r.Title,
Context: "",
}
docs = append(docs, c)
}
// find matching sections
likeQuery = "p.title LIKE '%" + keywords + "%'"
temp = []link.Candidate{}
err = s.Runtime.Db.Select(&temp,
`SELECT p.refid as targetid, p.documentid as documentid, p.title as title, p.pagetype as linktype, d.title as context, d.labelid as folderid from page p
LEFT JOIN document d ON d.refid=p.documentid WHERE p.orgid=? AND `+likeQuery+` AND d.labelid IN
(SELECT refid from label WHERE orgid=? AND type=2 AND userid=?
UNION ALL SELECT refid FROM label a where orgid=? AND type=1 AND refid IN (SELECT labelid from labelrole WHERE orgid=? AND userid='' AND (canedit=1 OR canview=1))
UNION ALL SELECT refid FROM label a where orgid=? AND type=3 AND refid IN (SELECT labelid from labelrole WHERE orgid=? AND userid=? AND (canedit=1 OR canview=1)))
ORDER BY p.title`,
ctx.OrgID,
ctx.OrgID,
ctx.UserID,
ctx.OrgID,
ctx.OrgID,
ctx.OrgID,
ctx.OrgID,
ctx.UserID)
if err != nil {
err = errors.Wrap(err, "execute search links 2")
return
}
for _, r := range temp {
c := link.Candidate{
RefID: uniqueid.Generate(),
FolderID: r.FolderID,
DocumentID: r.DocumentID,
TargetID: r.TargetID,
LinkType: r.LinkType,
Title: r.Title,
Context: r.Context,
}
pages = append(pages, c)
}
// find matching attachments
likeQuery = "a.filename LIKE '%" + keywords + "%'"
temp = []link.Candidate{}
err = s.Runtime.Db.Select(&temp,
`SELECT a.refid as targetid, a.documentid as documentid, a.filename as title, a.extension as context, d.labelid as folderid from attachment a
LEFT JOIN document d ON d.refid=a.documentid WHERE a.orgid=? AND `+likeQuery+` AND d.labelid IN
(SELECT refid from label WHERE orgid=? AND type=2 AND userid=?
UNION ALL SELECT refid FROM label a where orgid=? AND type=1 AND refid IN (SELECT labelid from labelrole WHERE orgid=? AND userid='' AND (canedit=1 OR canview=1))
UNION ALL SELECT refid FROM label a where orgid=? AND type=3 AND refid IN (SELECT labelid from labelrole WHERE orgid=? AND userid=? AND (canedit=1 OR canview=1)))
ORDER BY a.filename`,
ctx.OrgID,
ctx.OrgID,
ctx.UserID,
ctx.OrgID,
ctx.OrgID,
ctx.OrgID,
ctx.OrgID,
ctx.UserID)
if err != nil {
err = errors.Wrap(err, "execute search links 3")
return
}
for _, r := range temp {
c := link.Candidate{
RefID: uniqueid.Generate(),
FolderID: r.FolderID,
DocumentID: r.DocumentID,
TargetID: r.TargetID,
LinkType: "file",
Title: r.Title,
Context: r.Context,
}
attachments = append(attachments, c)
}
if len(docs) == 0 {
docs = []link.Candidate{}
}
if len(pages) == 0 {
pages = []link.Candidate{}
}
if len(attachments) == 0 {
attachments = []link.Candidate{}
}
return
}

182
domain/meta/endpoint.go Normal file
View file

@ -0,0 +1,182 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package meta
import (
"bytes"
"fmt"
"net/http"
"text/template"
"github.com/documize/community/core/env"
"github.com/documize/community/core/log"
"github.com/documize/community/core/response"
"github.com/documize/community/core/stringutil"
"github.com/documize/community/domain"
"github.com/documize/community/domain/auth"
"github.com/documize/community/domain/organization"
"github.com/documize/community/model/doc"
"github.com/documize/community/model/org"
"github.com/documize/community/model/space"
)
// Handler contains the runtime information such as logging and database.
type Handler struct {
Runtime *env.Runtime
Store *domain.Store
}
// Meta provides org meta data based upon request domain (e.g. acme.documize.com).
func (h *Handler) Meta(w http.ResponseWriter, r *http.Request) {
ctx := domain.GetRequestContext(r)
data := org.SiteMeta{}
data.URL = organization.GetSubdomainFromHost(r)
org, err := h.Store.Organization.GetOrganizationByDomain(ctx, data.URL)
if err != nil {
response.WriteForbiddenError(w)
return
}
data.OrgID = org.RefID
data.Title = org.Title
data.Message = org.Message
data.AllowAnonymousAccess = org.AllowAnonymousAccess
data.AuthProvider = org.AuthProvider
data.AuthConfig = org.AuthConfig
data.Version = h.Runtime.Product.Version
data.Edition = h.Runtime.Product.License.Edition
data.Valid = h.Runtime.Product.License.Valid
data.ConversionEndpoint = org.ConversionEndpoint
// Strip secrets
data.AuthConfig = auth.StripAuthSecrets(h.Runtime, org.AuthProvider, org.AuthConfig)
response.WriteJSON(w, data)
}
// RobotsTxt returns robots.txt depending on site configuration.
// Did we allow anonymouse access?
func (h *Handler) RobotsTxt(w http.ResponseWriter, r *http.Request) {
method := "GetRobots"
ctx := domain.GetRequestContext(r)
dom := organization.GetSubdomainFromHost(r)
org, err := h.Store.Organization.GetOrganizationByDomain(ctx, dom)
// default is to deny
robots :=
`User-agent: *
Disallow: /
`
if err != nil {
h.Runtime.Log.Error(fmt.Sprintf("%s failed to get Organization for domain %s", method, dom), err)
}
// Anonymous access would mean we allow bots to crawl.
if org.AllowAnonymousAccess {
sitemap := ctx.GetAppURL("sitemap.xml")
robots = fmt.Sprintf(
`User-agent: *
Disallow: /settings/
Disallow: /settings/*
Disallow: /profile/
Disallow: /profile/*
Disallow: /auth/login/
Disallow: /auth/login/
Disallow: /auth/logout/
Disallow: /auth/logout/*
Disallow: /auth/reset/*
Disallow: /auth/reset/*
Disallow: /auth/sso/
Disallow: /auth/sso/*
Disallow: /share
Disallow: /share/*
Sitemap: %s`, sitemap)
}
response.WriteBytes(w, []byte(robots))
}
// Sitemap returns URLs that can be indexed.
// We only include public folders and documents (e.g. can be seen by everyone).
func (h *Handler) Sitemap(w http.ResponseWriter, r *http.Request) {
method := "meta.Sitemap"
ctx := domain.GetRequestContext(r)
dom := organization.GetSubdomainFromHost(r)
org, err := h.Store.Organization.GetOrganizationByDomain(ctx, dom)
if err != nil {
h.Runtime.Log.Error(fmt.Sprintf("%s failed to get Organization for domain %s", method, dom), err)
}
sitemap :=
`<?xml version="1.0" encoding="UTF-8"?>
<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.sitemaps.org/schemas/sitemap/0.9 http://www.sitemaps.org/schemas/sitemap/0.9/sitemap.xsd">
{{range .}}<url>
<loc>{{ .URL }}</loc>
<lastmod>{{ .Date }}</lastmod>
</url>{{end}}
</urlset>`
var items []sitemapItem
// Anonymous access means we announce folders/documents shared with 'Everyone'.
if org.AllowAnonymousAccess {
// Grab shared folders
folders, err := h.Store.Space.PublicSpaces(ctx, org.RefID)
if err != nil {
folders = []space.Space{}
h.Runtime.Log.Error(fmt.Sprintf("%s failed to get folders for domain %s", method, dom), err)
}
for _, folder := range folders {
var item sitemapItem
item.URL = ctx.GetAppURL(fmt.Sprintf("s/%s/%s", folder.RefID, stringutil.MakeSlug(folder.Name)))
item.Date = folder.Revised.Format("2006-01-02T15:04:05.999999-07:00")
items = append(items, item)
}
// Grab documents from shared folders
var documents []doc.SitemapDocument
documents, err = h.Store.Document.PublicDocuments(ctx, org.RefID)
if err != nil {
documents = []doc.SitemapDocument{}
h.Runtime.Log.Error(fmt.Sprintf("%s failed to get documents for domain %s", method, dom), err)
}
for _, document := range documents {
var item sitemapItem
item.URL = ctx.GetAppURL(fmt.Sprintf("s/%s/%s/d/%s/%s",
document.FolderID, stringutil.MakeSlug(document.Folder), document.DocumentID, stringutil.MakeSlug(document.Document)))
item.Date = document.Revised.Format("2006-01-02T15:04:05.999999-07:00")
items = append(items, item)
}
}
buffer := new(bytes.Buffer)
t := template.Must(template.New("tmp").Parse(sitemap))
log.IfErr(t.Execute(buffer, &items))
response.WriteBytes(w, buffer.Bytes())
}
// sitemapItem provides a means to teleport somewhere else for free.
// What did you think it did?
type sitemapItem struct {
URL string
Date string
}

View file

@ -0,0 +1,39 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package mysql
import (
"fmt"
"github.com/documize/community/core/env"
"github.com/documize/community/domain"
"github.com/documize/community/model/page"
"github.com/pkg/errors"
)
// Scope provides data access to MySQL.
type Scope struct {
Runtime *env.Runtime
}
// GetPagesWithoutContent returns a slice containing all the page records for a given documentID, in presentation sequence,
// but without the body field (which holds the HTML content).
func (s Scope) GetPagesWithoutContent(ctx domain.RequestContext, documentID string) (pages []page.Page, err error) {
err = s.Runtime.Db.Select(&pages, "SELECT id, refid, orgid, documentid, userid, contenttype, pagetype, sequence, level, title, revisions, blockid, created, revised FROM page WHERE orgid=? AND documentid=? ORDER BY sequence", ctx.OrgID, documentID)
if err != nil {
err = errors.Wrap(err, fmt.Sprintf("Unable to execute select pages for org %s and document %s", ctx.OrgID, documentID))
return
}
return
}

248
domain/setting/endpoint.go Normal file
View file

@ -0,0 +1,248 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
// Package setting manages both global and user level settings
package setting
import (
"encoding/json"
"encoding/xml"
"io/ioutil"
"net/http"
"github.com/documize/community/core/env"
"github.com/documize/community/core/event"
"github.com/documize/community/core/response"
"github.com/documize/community/domain"
"github.com/documize/community/model/audit"
)
// Handler contains the runtime information such as logging and database.
type Handler struct {
Runtime *env.Runtime
Store *domain.Store
}
// SMTP returns installation-wide SMTP settings
func (h *Handler) SMTP(w http.ResponseWriter, r *http.Request) {
method := "setting.SMTP"
ctx := domain.GetRequestContext(r)
if !ctx.Global {
response.WriteForbiddenError(w)
return
}
config := h.Store.Setting.Get(ctx, "SMTP", "")
var y map[string]interface{}
json.Unmarshal([]byte(config), &y)
j, err := json.Marshal(y)
if err != nil {
response.WriteBadRequestError(w, method, err.Error())
return
}
response.WriteBytes(w, j)
}
// SetSMTP persists global SMTP configuration.
func (h *Handler) SetSMTP(w http.ResponseWriter, r *http.Request) {
method := "setting.SetSMTP"
ctx := domain.GetRequestContext(r)
if !ctx.Global {
response.WriteForbiddenError(w)
return
}
defer r.Body.Close()
body, err := ioutil.ReadAll(r.Body)
if err != nil {
response.WriteBadRequestError(w, method, err.Error())
return
}
var config string
config = string(body)
ctx.Transaction, err = h.Runtime.Db.Beginx()
if err != nil {
response.WriteServerError(w, method, err)
return
}
h.Store.Setting.Set(ctx, "SMTP", config)
h.Store.Audit.Record(ctx, audit.EventTypeSystemSMTP)
response.WriteEmpty(w)
}
// License returns product license
func (h *Handler) License(w http.ResponseWriter, r *http.Request) {
ctx := domain.GetRequestContext(r)
if !ctx.Global {
response.WriteForbiddenError(w)
return
}
config := h.Store.Setting.Get(ctx, "EDITION-LICENSE", "")
if len(config) == 0 {
config = "{}"
}
x := &licenseXML{Key: "", Signature: ""}
lj := licenseJSON{}
err := json.Unmarshal([]byte(config), &lj)
if err == nil {
x.Key = lj.Key
x.Signature = lj.Signature
} else {
h.Runtime.Log.Error("failed to JSON unmarshal EDITION-LICENSE", err)
}
output, err := xml.Marshal(x)
if err != nil {
h.Runtime.Log.Error("failed to XML marshal EDITION-LICENSE", err)
}
response.WriteBytes(w, output)
}
// SetLicense persists product license
func (h *Handler) SetLicense(w http.ResponseWriter, r *http.Request) {
method := "setting.SetLicense"
ctx := domain.GetRequestContext(r)
if !ctx.Global {
response.WriteForbiddenError(w)
return
}
defer r.Body.Close()
body, err := ioutil.ReadAll(r.Body)
if err != nil {
response.WriteBadRequestError(w, method, err.Error())
return
}
var config string
config = string(body)
lj := licenseJSON{}
x := licenseXML{Key: "", Signature: ""}
err = xml.Unmarshal([]byte(config), &x)
if err == nil {
lj.Key = x.Key
lj.Signature = x.Signature
} else {
h.Runtime.Log.Error("failed to XML unmarshal EDITION-LICENSE", err)
}
j, err := json.Marshal(lj)
js := "{}"
if err == nil {
js = string(j)
}
h.Store.Setting.Set(ctx, "EDITION-LICENSE", js)
event.Handler().Publish(string(event.TypeSystemLicenseChange))
ctx.Transaction, err = h.Runtime.Db.Beginx()
if err != nil {
response.WriteServerError(w, method, err)
return
}
h.Store.Audit.Record(ctx, audit.EventTypeSystemLicense)
ctx.Transaction.Commit()
response.WriteEmpty(w)
}
// AuthConfig returns installation-wide auth configuration
func (h *Handler) AuthConfig(w http.ResponseWriter, r *http.Request) {
ctx := domain.GetRequestContext(r)
if !ctx.Global {
response.WriteForbiddenError(w)
return
}
org, err := h.Store.Organization.GetOrganization(ctx, ctx.OrgID)
if err != nil {
response.WriteForbiddenError(w)
return
}
response.WriteJSON(w, org.AuthConfig)
}
// SetAuthConfig persists installation-wide authentication configuration
func (h *Handler) SetAuthConfig(w http.ResponseWriter, r *http.Request) {
method := "SaveAuthConfig"
ctx := domain.GetRequestContext(r)
if !ctx.Global {
response.WriteForbiddenError(w)
return
}
defer r.Body.Close()
body, err := ioutil.ReadAll(r.Body)
if err != nil {
response.WriteBadRequestError(w, method, err.Error())
return
}
var data authData
err = json.Unmarshal(body, &data)
if err != nil {
response.WriteBadRequestError(w, method, err.Error())
return
}
org, err := h.Store.Organization.GetOrganization(ctx, ctx.OrgID)
if err != nil {
response.WriteServerError(w, method, err)
return
}
org.AuthProvider = data.AuthProvider
org.AuthConfig = data.AuthConfig
ctx.Transaction, err = h.Runtime.Db.Beginx()
if err != nil {
response.WriteServerError(w, method, err)
return
}
err = h.Store.Organization.UpdateAuthConfig(ctx, org)
if err != nil {
ctx.Transaction.Rollback()
response.WriteServerError(w, method, err)
return
}
h.Store.Audit.Record(ctx, audit.EventTypeSystemAuth)
ctx.Transaction.Commit()
response.WriteEmpty(w)
}

38
domain/setting/model.go Normal file
View file

@ -0,0 +1,38 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
// Package setting manages both global and user level settings
package setting
import "encoding/xml"
type licenseXML struct {
XMLName xml.Name `xml:"DocumizeLicense"`
Key string
Signature string
}
type licenseJSON struct {
Key string `json:"key"`
Signature string `json:"signature"`
}
type authData struct {
AuthProvider string `json:"authProvider"`
AuthConfig string `json:"authConfig"`
}
/*
<DocumizeLicense>
<Key>some key</Key>
<Signature>some signature</Signature>
</DocumizeLicense>
*/

View file

@ -0,0 +1,135 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package mysql
import (
"bytes"
"fmt"
"github.com/documize/community/core/env"
"github.com/documize/community/core/streamutil"
"github.com/documize/community/domain"
"github.com/pkg/errors"
)
// Scope provides data access to MySQL.
type Scope struct {
Runtime *env.Runtime
}
// Get fetches a configuration JSON element from the config table.
func (s Scope) Get(ctx domain.RequestContext, area, path string) (value string) {
if path != "" {
path = "." + path
}
sql := "SELECT JSON_EXTRACT(`config`,'$" + path + "') FROM `config` WHERE `key` = '" + area + "';"
stmt, err := s.Runtime.Db.Preparex(sql)
defer streamutil.Close(stmt)
if err != nil {
s.Runtime.Log.Error(fmt.Sprintf("setting.Get %s %s", area, path), err)
return ""
}
var item = make([]uint8, 0)
err = stmt.Get(&item)
if err != nil {
s.Runtime.Log.Error(fmt.Sprintf("setting.Get %s %s", area, path), err)
return ""
}
if len(item) > 1 {
q := []byte(`"`)
value = string(bytes.TrimPrefix(bytes.TrimSuffix(item, q), q))
}
return value
}
// Set writes a configuration JSON element to the config table.
func (s Scope) Set(ctx domain.RequestContext, area, json string) error {
if area == "" {
return errors.New("no area")
}
sql := "INSERT INTO `config` (`key`,`config`) " +
"VALUES ('" + area + "','" + json +
"') ON DUPLICATE KEY UPDATE `config`='" + json + "';"
stmt, err := s.Runtime.Db.Preparex(sql)
defer streamutil.Close(stmt)
if err != nil {
err = errors.Wrap(err, "failed to save global config value")
return err
}
_, err = stmt.Exec()
return err
}
// GetUser fetches a configuration JSON element from the userconfig table for a given orgid/userid combination.
// Errors return the empty string. A blank path returns the whole JSON object, as JSON.
func (s Scope) GetUser(ctx domain.RequestContext, orgID, userID, area, path string) (value string) {
if path != "" {
path = "." + path
}
sql := "SELECT JSON_EXTRACT(`config`,'$" + path + "') FROM `userconfig` WHERE `key` = '" + area +
"' AND `orgid` = '" + orgID + "' AND `userid` = '" + userID + "';"
stmt, err := s.Runtime.Db.Preparex(sql)
defer streamutil.Close(stmt)
if err != nil {
return ""
}
var item = make([]uint8, 0)
err = stmt.Get(&item)
if err != nil {
s.Runtime.Log.Error(fmt.Sprintf("setting.GetUser for user %s %s %s", userID, area, path), err)
return ""
}
if len(item) > 1 {
q := []byte(`"`)
value = string(bytes.TrimPrefix(bytes.TrimSuffix(item, q), q))
}
return value
}
// SetUser writes a configuration JSON element to the userconfig table for the current user.
func (s Scope) SetUser(ctx domain.RequestContext, orgID, userID, area, json string) error {
if area == "" {
return errors.New("no area")
}
sql := "INSERT INTO `userconfig` (`orgid`,`userid`,`key`,`config`) " +
"VALUES ('" + orgID + "','" + userID + "','" + area + "','" + json +
"') ON DUPLICATE KEY UPDATE `config`='" + json + "';"
stmt, err := s.Runtime.Db.Preparex(sql)
defer streamutil.Close(stmt)
if err != nil {
return err
}
_, err = stmt.Exec()
return err
}

View file

@ -14,8 +14,12 @@ package domain
import ( import (
"github.com/documize/community/model/account" "github.com/documize/community/model/account"
"github.com/documize/community/model/attachment"
"github.com/documize/community/model/audit" "github.com/documize/community/model/audit"
"github.com/documize/community/model/doc"
"github.com/documize/community/model/link"
"github.com/documize/community/model/org" "github.com/documize/community/model/org"
"github.com/documize/community/model/page"
"github.com/documize/community/model/pin" "github.com/documize/community/model/pin"
"github.com/documize/community/model/space" "github.com/documize/community/model/space"
"github.com/documize/community/model/user" "github.com/documize/community/model/user"
@ -30,6 +34,10 @@ type Store struct {
Pin PinStorer Pin PinStorer
Audit AuditStorer Audit AuditStorer
Document DocumentStorer Document DocumentStorer
Setting SettingStorer
Attachment AttachmentStorer
Link LinkStorer
Page PageStorer
} }
// SpaceStorer defines required methods for space management // SpaceStorer defines required methods for space management
@ -112,7 +120,43 @@ type AuditStorer interface {
// DocumentStorer defines required methods for document handling // DocumentStorer defines required methods for document handling
type DocumentStorer interface { type DocumentStorer interface {
Get(ctx RequestContext, id string) (document doc.Document, err error)
MoveDocumentSpace(ctx RequestContext, id, move string) (err error) MoveDocumentSpace(ctx RequestContext, id, move string) (err error)
PublicDocuments(ctx RequestContext, orgID string) (documents []doc.SitemapDocument, err error)
} }
// https://github.com/golang-sql/sqlexp/blob/c2488a8be21d20d31abf0d05c2735efd2d09afe4/quoter.go#L46 // SettingStorer defines required methods for persisting global and user level settings
type SettingStorer interface {
Get(ctx RequestContext, area, path string) string
Set(ctx RequestContext, area, value string) error
GetUser(ctx RequestContext, orgID, userID, area, path string) string
SetUser(ctx RequestContext, orgID, userID, area, json string) error
}
// AttachmentStorer defines required methods for persisting document attachments
type AttachmentStorer interface {
Add(ctx RequestContext, a attachment.Attachment) (err error)
GetAttachment(ctx RequestContext, orgID, attachmentID string) (a attachment.Attachment, err error)
GetAttachments(ctx RequestContext, docID string) (a []attachment.Attachment, err error)
GetAttachmentsWithData(ctx RequestContext, docID string) (a []attachment.Attachment, err error)
Delete(ctx RequestContext, id string) (rows int64, err error)
}
// LinkStorer defines required methods for persisting content links
type LinkStorer interface {
Add(ctx RequestContext, l link.Link) (err error)
SearchCandidates(ctx RequestContext, keywords string) (docs []link.Candidate, pages []link.Candidate, attachments []link.Candidate, err error)
GetDocumentOutboundLinks(ctx RequestContext, documentID string) (links []link.Link, err error)
GetPageLinks(ctx RequestContext, documentID, pageID string) (links []link.Link, err error)
MarkOrphanDocumentLink(ctx RequestContext, documentID string) (err error)
MarkOrphanPageLink(ctx RequestContext, pageID string) (err error)
MarkOrphanAttachmentLink(ctx RequestContext, attachmentID string) (err error)
DeleteSourcePageLinks(ctx RequestContext, pageID string) (rows int64, err error)
DeleteSourceDocumentLinks(ctx RequestContext, documentID string) (rows int64, err error)
DeleteLink(ctx RequestContext, id string) (rows int64, err error)
}
// PageStorer defines required methods for persisting document pages
type PageStorer interface {
GetPagesWithoutContent(ctx RequestContext, documentID string) (pages []page.Page, err error)
}

View file

View file

@ -12,27 +12,48 @@
package user package user
import ( import (
"database/sql"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"strconv"
"github.com/documize/community/core/api/mail"
"github.com/documize/community/core/env" "github.com/documize/community/core/env"
"github.com/documize/community/core/event"
"github.com/documize/community/core/request"
"github.com/documize/community/core/response"
"github.com/documize/community/core/secrets"
"github.com/documize/community/core/streamutil"
"github.com/documize/community/core/stringutil"
"github.com/documize/community/core/uniqueid"
"github.com/documize/community/domain" "github.com/documize/community/domain"
"github.com/documize/community/model/account"
"github.com/documize/community/model/audit"
"github.com/documize/community/model/space"
"github.com/documize/community/model/user"
) )
// Handler contains the runtime information such as logging and database. // Handler contains the runtime information such as logging and database.
type Handler struct { type Handler struct {
Runtime *env.Runtime Runtime *env.Runtime
Store domain.Store Store *domain.Store
} }
/* // Add is the endpoint that enables an administrator to add a new user for their orgaisation.
// AddUser is the endpoint that enables an administrator to add a new user for their orgaisation. func (h *Handler) Add(w http.ResponseWriter, r *http.Request) {
func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) { method := "user.Add"
method := "user.AddUser"
ctx := domain.GetRequestContext(r) ctx := domain.GetRequestContext(r)
if !h.Runtime.Product.License.IsValid() { if !h.Runtime.Product.License.IsValid() {
response.WriteBadLicense(w) response.WriteBadLicense(w)
} }
if !s.Context.Administrator { if !ctx.Administrator {
response.WriteForbiddenError(w) response.WriteForbiddenError(w)
return return
} }
@ -44,7 +65,7 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) {
return return
} }
userModel := model.User{} userModel := user.User{}
err = json.Unmarshal(body, &userModel) err = json.Unmarshal(body, &userModel)
if err != nil { if err != nil {
response.WriteBadRequestError(w, method, err.Error()) response.WriteBadRequestError(w, method, err.Error())
@ -82,7 +103,7 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) {
addAccount := true addAccount := true
var userID string var userID string
userDupe, err := h.Store.User.GetByEmail(s, userModel.Email) userDupe, err := h.Store.User.GetByEmail(ctx, userModel.Email)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
@ -95,7 +116,7 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) {
h.Runtime.Log.Info("Dupe user found, will not add " + userModel.Email) h.Runtime.Log.Info("Dupe user found, will not add " + userModel.Email)
} }
s.Context.Transaction, err = request.Db.Beginx() ctx.Transaction, err = h.Runtime.Db.Beginx()
if err != nil { if err != nil {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
@ -105,19 +126,19 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) {
userID = uniqueid.Generate() userID = uniqueid.Generate()
userModel.RefID = userID userModel.RefID = userID
err = h.Store.User.Add(s, userModel) err = h.Store.User.Add(ctx, userModel)
if err != nil { if err != nil {
s.Context.Transaction.Rollback() ctx.Transaction.Rollback()
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
h.Runtime.Log.Info("Adding user") h.Runtime.Log.Info("Adding user")
} else { } else {
AttachUserAccounts(s, s.Context.OrgID, &userDupe) AttachUserAccounts(ctx, *h.Store, ctx.OrgID, &userDupe)
for _, a := range userDupe.Accounts { for _, a := range userDupe.Accounts {
if a.OrgID == s.Context.OrgID { if a.OrgID == ctx.OrgID {
addAccount = false addAccount = false
h.Runtime.Log.Info("Dupe account found, will not add") h.Runtime.Log.Info("Dupe account found, will not add")
break break
@ -127,17 +148,17 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) {
// set up user account for the org // set up user account for the org
if addAccount { if addAccount {
var a model.Account var a account.Account
a.RefID = uniqueid.Generate() a.RefID = uniqueid.Generate()
a.UserID = userID a.UserID = userID
a.OrgID = s.Context.OrgID a.OrgID = ctx.OrgID
a.Editor = true a.Editor = true
a.Admin = false a.Admin = false
a.Active = true a.Active = true
err = account.Add(s, a) err = h.Store.Account.Add(ctx, a)
if err != nil { if err != nil {
s.Context.Transaction.Rollback() ctx.Transaction.Rollback()
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
@ -145,15 +166,15 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) {
if addUser { if addUser {
event.Handler().Publish(string(event.TypeAddUser)) event.Handler().Publish(string(event.TypeAddUser))
eventing.Record(s, eventing.EventTypeUserAdd) h.Store.Audit.Record(ctx, audit.EventTypeUserAdd)
} }
if addAccount { if addAccount {
event.Handler().Publish(string(event.TypeAddAccount)) event.Handler().Publish(string(event.TypeAddAccount))
eventing.Record(s, eventing.EventTypeAccountAdd) h.Store.Audit.Record(ctx, audit.EventTypeAccountAdd)
} }
s.Context.Transaction.Commit() ctx.Transaction.Commit()
// If we did not add user or give them access (account) then we error back // If we did not add user or give them access (account) then we error back
if !addUser && !addAccount { if !addUser && !addAccount {
@ -162,7 +183,7 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) {
} }
// Invite new user // Invite new user
inviter, err := h.Store.User.Get(s, s.Context.UserID) inviter, err := h.Store.User.Get(ctx, ctx.UserID)
if err != nil { if err != nil {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
@ -172,16 +193,16 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) {
if addUser && addAccount { if addUser && addAccount {
size := len(requestedPassword) size := len(requestedPassword)
auth := fmt.Sprintf("%s:%s:%s", s.Context.AppURL, userModel.Email, requestedPassword[:size]) auth := fmt.Sprintf("%s:%s:%s", ctx.AppURL, userModel.Email, requestedPassword[:size])
encrypted := secrets.EncodeBase64([]byte(auth)) encrypted := secrets.EncodeBase64([]byte(auth))
url := fmt.Sprintf("%s/%s", s.Context.GetAppURL("auth/sso"), url.QueryEscape(string(encrypted))) url := fmt.Sprintf("%s/%s", ctx.GetAppURL("auth/sso"), url.QueryEscape(string(encrypted)))
go mail.InviteNewUser(userModel.Email, inviter.Fullname(), url, userModel.Email, requestedPassword) go mail.InviteNewUser(userModel.Email, inviter.Fullname(), url, userModel.Email, requestedPassword)
h.Runtime.Log.Info(fmt.Sprintf("%s invited by %s on %s", userModel.Email, inviter.Email, s.Context.AppURL)) h.Runtime.Log.Info(fmt.Sprintf("%s invited by %s on %s", userModel.Email, inviter.Email, ctx.AppURL))
} else { } else {
go mail.InviteExistingUser(userModel.Email, inviter.Fullname(), s.Context.GetAppURL("")) go mail.InviteExistingUser(userModel.Email, inviter.Fullname(), ctx.GetAppURL(""))
h.Runtime.Log.Info(fmt.Sprintf("%s is giving access to an existing user %s", inviter.Email, userModel.Email)) h.Runtime.Log.Info(fmt.Sprintf("%s is giving access to an existing user %s", inviter.Email, userModel.Email))
} }
@ -189,33 +210,32 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) {
response.WriteJSON(w, userModel) response.WriteJSON(w, userModel)
} }
/*
// GetOrganizationUsers is the endpoint that allows administrators to view the users in their organisation. // GetOrganizationUsers is the endpoint that allows administrators to view the users in their organisation.
func (h *Handler) GetOrganizationUsers(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetOrganizationUsers(w http.ResponseWriter, r *http.Request) {
method := "pin.GetUserPins" method := "user.GetOrganizationUsers"
s := domain.NewContext(h.Runtime, r) ctx := domain.GetRequestContext(r)
if !s.Context.Editor && !s.Context.Administrator { if !ctx.Administrator {
response.WriteForbiddenError(w) response.WriteForbiddenError(w)
return return
} }
active, err := strconv.ParseBool(request.Query("active")) active, err := strconv.ParseBool(request.Query(r, "active"))
if err != nil { if err != nil {
active = false active = false
} }
u := []User{} u := []user.User{}
if active { if active {
u, err = GetActiveUsersForOrganization(s) u, err = h.Store.User.GetActiveUsersForOrganization(ctx)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
} else { } else {
u, err = GetUsersForOrganization(s) u, err = h.Store.User.GetUsersForOrganization(ctx)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
@ -223,11 +243,11 @@ func (h *Handler) GetOrganizationUsers(w http.ResponseWriter, r *http.Request) {
} }
if len(u) == 0 { if len(u) == 0 {
u = []User{} u = []user.User{}
} }
for i := range u { for i := range u {
AttachUserAccounts(s, s.Context.OrgID, &u[i]) AttachUserAccounts(ctx, *h.Store, ctx.OrgID, &u[i])
} }
response.WriteJSON(w, u) response.WriteJSON(w, u)
@ -236,19 +256,19 @@ func (h *Handler) GetOrganizationUsers(w http.ResponseWriter, r *http.Request) {
// GetSpaceUsers returns every user within a given space // GetSpaceUsers returns every user within a given space
func (h *Handler) GetSpaceUsers(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetSpaceUsers(w http.ResponseWriter, r *http.Request) {
method := "user.GetSpaceUsers" method := "user.GetSpaceUsers"
s := domain.NewContext(h.Runtime, r) ctx := domain.GetRequestContext(r)
var u []User var u []user.User
var err error var err error
folderID := request.Param("folderID") folderID := request.Param(r, "folderID")
if len(folderID) == 0 { if len(folderID) == 0 {
response.WriteMissingDataError(w, method, "folderID") response.WriteMissingDataError(w, method, "folderID")
return return
} }
// check to see space type as it determines user selection criteria // check to see space type as it determines user selection criteria
folder, err := space.Get(s, folderID) folder, err := h.Store.Space.Get(ctx, folderID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
h.Runtime.Log.Error("cannot get space", err) h.Runtime.Log.Error("cannot get space", err)
response.WriteJSON(w, u) response.WriteJSON(w, u)
@ -256,22 +276,22 @@ func (h *Handler) GetSpaceUsers(w http.ResponseWriter, r *http.Request) {
} }
switch folder.Type { switch folder.Type {
case entity.FolderTypePublic: case space.ScopePublic:
u, err = GetActiveUsersForOrganization(s) u, err = h.Store.User.GetActiveUsersForOrganization(ctx)
break break
case entity.FolderTypePrivate: case space.ScopePrivate:
// just me // just me
var me User var me user.User
user, err = Get(s, s.Context.UserID) me, err = h.Store.User.Get(ctx, ctx.UserID)
u = append(u, me) u = append(u, me)
break break
case entity.FolderTypeRestricted: case space.ScopeRestricted:
u, err = GetSpaceUsers(s, folderID) u, err = h.Store.User.GetSpaceUsers(ctx, folderID)
break break
} }
if len(u) == 0 { if len(u) == 0 {
u = []User u = []user.User{}
} }
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
@ -283,25 +303,25 @@ func (h *Handler) GetSpaceUsers(w http.ResponseWriter, r *http.Request) {
response.WriteJSON(w, u) response.WriteJSON(w, u)
} }
// GetUser returns user specified by ID // Get returns user specified by ID
func (h *Handler) GetUser(w http.ResponseWriter, r *http.Request) { func (h *Handler) Get(w http.ResponseWriter, r *http.Request) {
method := "user.GetUser" method := "user.Get"
s := domain.NewContext(h.Runtime, r) ctx := domain.GetRequestContext(r)
userID := request.Param("userID") userID := request.Param(r, "userID")
if len(userID) == 0 { if len(userID) == 0 {
response.WriteMissingDataError(w, method, "userId") response.WriteMissingDataError(w, method, "userId")
return return
} }
if userID != s.Context.UserID { if userID != ctx.UserID {
response.WriteBadRequestError(w, method, "userId mismatch") response.WriteBadRequestError(w, method, "userId mismatch")
return return
} }
u, err := GetSecuredUser(s, s.Context.OrgID, userID) u, err := GetSecuredUser(ctx, *h.Store, ctx.OrgID, userID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
response.WriteNotFoundError(s, method, s.Context.UserID) response.WriteNotFoundError(w, method, ctx.UserID)
return return
} }
if err != nil { if err != nil {
@ -309,66 +329,63 @@ func (h *Handler) GetUser(w http.ResponseWriter, r *http.Request) {
return return
} }
response.WriteJSON(u) response.WriteJSON(w, u)
}
// DeleteUser is the endpoint to delete a user specified by userID, the caller must be an Administrator. // Delete is the endpoint to delete a user specified by userID, the caller must be an Administrator.
func (h *Handler) DeleteUser(w http.ResponseWriter, r *http.Request) { func (h *Handler) Delete(w http.ResponseWriter, r *http.Request) {
method := "user.DeleteUser" method := "user.Delete"
s := domain.NewContext(h.Runtime, r) ctx := domain.GetRequestContext(r)
if !s.Context.Administrator { userID := request.Param(r, "userID")
response.WriteForbiddenError(w)
return
}
userID := response.Params("userID")
if len(userID) == 0 { if len(userID) == 0 {
response.WriteMissingDataError(w, method, "userID") response.WriteMissingDataError(w, method, "userId")
return return
} }
if userID == s.Context.UserID { if userID == ctx.UserID {
response.WriteBadRequestError(w, method, "cannot delete self") response.WriteBadRequestError(w, method, "cannot delete self")
return return
} }
var err error var err error
s.Context.Transaction, err = h.Runtime.Db.Beginx() ctx.Transaction, err = h.Runtime.Db.Beginx()
if err != nil { if err != nil {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
err = DeactiveUser(s, userID) err = h.Store.User.DeactiveUser(ctx, userID)
if err != nil { if err != nil {
s.Context.Transaction.Rollback() ctx.Transaction.Rollback()
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
err = space.ChangeLabelOwner(s, userID, s.Context.UserID) err = h.Store.Space.ChangeOwner(ctx, userID, ctx.UserID)
if err != nil { if err != nil {
s.Context.Transaction.Rollback() ctx.Transaction.Rollback()
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
eventing.Record(s, eventing.EventTypeUserDelete) h.Store.Audit.Record(ctx, audit.EventTypeUserDelete)
event.Handler().Publish(string(event.TypeRemoveUser)) event.Handler().Publish(string(event.TypeRemoveUser))
s.Context.Transaction.Commit() ctx.Transaction.Commit()
response.WriteEmpty() response.WriteEmpty(w)
} }
// UpdateUser is the endpoint to update user information for the given userID. // Update is the endpoint to update user information for the given userID.
// Note that unless they have admin privildges, a user can only update their own information. // Note that unless they have admin privildges, a user can only update their own information.
// Also, only admins can update user roles in organisations. // Also, only admins can update user roles in organisations.
func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) { func (h *Handler) Update(w http.ResponseWriter, r *http.Request) {
method := "user.DeleteUser" method := "user.Update"
s := domain.NewContext(h.Runtime, r) ctx := domain.GetRequestContext(r)
userID := request.Param("userID") userID := request.Param(r, "userID")
if len(userID) == 0 { if len(userID) == 0 {
response.WriteBadRequestError(w, method, "user id must be numeric") response.WriteBadRequestError(w, method, "user id must be numeric")
return return
@ -377,11 +394,11 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) {
defer streamutil.Close(r.Body) defer streamutil.Close(r.Body)
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
response.WritePayloadError(w, method, err) response.WriteBadRequestError(w, method, err.Error())
return return
} }
u := User{} u := user.User{}
err = json.Unmarshal(body, &u) err = json.Unmarshal(body, &u)
if err != nil { if err != nil {
response.WriteBadRequestError(w, method, err.Error()) response.WriteBadRequestError(w, method, err.Error())
@ -389,7 +406,7 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) {
} }
// can only update your own account unless you are an admin // can only update your own account unless you are an admin
if s.Context.UserID != userID && !s.Context.Administrator { if ctx.UserID != userID && !ctx.Administrator {
response.WriteForbiddenError(w) response.WriteForbiddenError(w)
return return
} }
@ -400,7 +417,7 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) {
return return
} }
s.Context.Transaction, err = h.Runtime.Db.Beginx() ctx.Transaction, err = h.Runtime.Db.Beginx()
if err != nil { if err != nil {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
@ -409,9 +426,9 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) {
u.RefID = userID u.RefID = userID
u.Initials = stringutil.MakeInitials(u.Firstname, u.Lastname) u.Initials = stringutil.MakeInitials(u.Firstname, u.Lastname)
err = UpdateUser(s, u) err = h.Store.User.UpdateUser(ctx, u)
if err != nil { if err != nil {
s.Context.Transaction.Rollback() ctx.Transaction.Rollback()
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
@ -419,9 +436,9 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) {
// Now we update user roles for this organization. // Now we update user roles for this organization.
// That means we have to first find their account record // That means we have to first find their account record
// for this organization. // for this organization.
a, err := account.GetUserAccount(s, userID) a, err := h.Store.Account.GetUserAccount(ctx, userID)
if err != nil { if err != nil {
s.Context.Transaction.Rollback() ctx.Transaction.Rollback()
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
@ -430,26 +447,26 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) {
a.Admin = u.Admin a.Admin = u.Admin
a.Active = u.Active a.Active = u.Active
err = account.UpdateAccount(s, account) err = h.Store.Account.UpdateAccount(ctx, a)
if err != nil { if err != nil {
s.Context.Transaction.Rollback() ctx.Transaction.Rollback()
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
eventing.Record(s, eventing.EventTypeUserUpdate) h.Store.Audit.Record(ctx, audit.EventTypeUserUpdate)
s.Context.Transaction.Commit() ctx.Transaction.Commit()
response.WriteJSON(u) response.WriteEmpty(w)
} }
// ChangeUserPassword accepts password change from within the app. // ChangePassword accepts password change from within the app.
func (h *Handler) ChangeUserPassword(w http.ResponseWriter, r *http.Request) { func (h *Handler) ChangePassword(w http.ResponseWriter, r *http.Request) {
method := "user.ChangeUserPassword" method := "user.ChangePassword"
s := domain.NewContext(h.Runtime, r) ctx := domain.GetRequestContext(r)
userID := response.Param("userID") userID := request.Param(r, "userID")
if len(userID) == 0 { if len(userID) == 0 {
response.WriteMissingDataError(w, method, "user id") response.WriteMissingDataError(w, method, "user id")
return return
@ -464,18 +481,18 @@ func (h *Handler) ChangeUserPassword(w http.ResponseWriter, r *http.Request) {
newPassword := string(body) newPassword := string(body)
// can only update your own account unless you are an admin // can only update your own account unless you are an admin
if userID != s.Context.UserID && !s.Context.Administrator { if userID != ctx.UserID || !ctx.Administrator {
response.WriteForbiddenError(w) response.WriteForbiddenError(w)
return return
} }
s.Context.Transaction, err = h.Runtime.Db.Beginx() ctx.Transaction, err = h.Runtime.Db.Beginx()
if err != nil { if err != nil {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
u, err := Get(s, userID) u, err := h.Store.User.Get(ctx, userID)
if err != nil { if err != nil {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
@ -483,28 +500,29 @@ func (h *Handler) ChangeUserPassword(w http.ResponseWriter, r *http.Request) {
u.Salt = secrets.GenerateSalt() u.Salt = secrets.GenerateSalt()
err = UpdateUserPassword(s, userID, user.Salt, secrets.GeneratePassword(newPassword, user.Salt)) err = h.Store.User.UpdateUserPassword(ctx, userID, u.Salt, secrets.GeneratePassword(newPassword, u.Salt))
if err != nil { if err != nil {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
s.Context.Transaction.Rollback() ctx.Transaction.Rollback()
response.WriteEmpty(w) response.WriteEmpty(w)
} }
// GetUserFolderPermissions returns folder permission for authenticated user. // UserSpacePermissions returns folder permission for authenticated user.
func (h *Handler) GetUserFolderPermissions(w http.ResponseWriter, r *http.Request) { func (h *Handler) UserSpacePermissions(w http.ResponseWriter, r *http.Request) {
method := "user.ChangeUserPassword" method := "user.UserSpacePermissions"
s := domain.NewContext(h.Runtime, r) ctx := domain.GetRequestContext(r)
userID := request.Param("userID") userID := request.Param(r, "userID")
if userID != p.Context.UserID { if userID != ctx.UserID {
response.WriteForbiddenError(w) response.WriteForbiddenError(w)
return return
} }
roles, err := space.GetUserLabelRoles(s, userID) roles, err := h.Store.Space.GetUserRoles(ctx)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
roles = []space.Role{} roles = []space.Role{}
@ -517,12 +535,12 @@ func (h *Handler) GetUserFolderPermissions(w http.ResponseWriter, r *http.Reques
response.WriteJSON(w, roles) response.WriteJSON(w, roles)
} }
// ForgotUserPassword initiates the change password procedure. // ForgotPassword initiates the change password procedure.
// Generates a reset token and sends email to the user. // Generates a reset token and sends email to the user.
// User has to click link in email and then provide a new password. // User has to click link in email and then provide a new password.
func (h *Handler) ForgotUserPassword(w http.ResponseWriter, r *http.Request) { func (h *Handler) ForgotPassword(w http.ResponseWriter, r *http.Request) {
method := "user.ForgotUserPassword" method := "user.ForgotPassword"
s := domain.NewContext(h.Runtime, r) ctx := domain.GetRequestContext(r)
defer streamutil.Close(r.Body) defer streamutil.Close(r.Body)
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
@ -531,14 +549,14 @@ func (h *Handler) ForgotUserPassword(w http.ResponseWriter, r *http.Request) {
return return
} }
u := new(User) u := new(user.User)
err = json.Unmarshal(body, &u) err = json.Unmarshal(body, &u)
if err != nil { if err != nil {
response.WriteBadRequestError(w, method, "JSON body") response.WriteBadRequestError(w, method, "JSON body")
return return
} }
s.Context.Transaction, err = request.Db.Beginx() ctx.Transaction, err = h.Runtime.Db.Beginx()
if err != nil { if err != nil {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
@ -546,33 +564,33 @@ func (h *Handler) ForgotUserPassword(w http.ResponseWriter, r *http.Request) {
token := secrets.GenerateSalt() token := secrets.GenerateSalt()
err = ForgotUserPassword(s, u.Email, token) err = h.Store.User.ForgotUserPassword(ctx, u.Email, token)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
s.Context.Transaction.Rollback() ctx.Transaction.Rollback()
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
response.WriteEmpty(w) response.WriteEmpty(w)
h.Runtime.Log.Info(fmt.Errorf("User %s not found for password reset process", u.Email)) h.Runtime.Log.Info(fmt.Sprintf("User %s not found for password reset process", u.Email))
return return
} }
s.Context.Transaction.Commit() ctx.Transaction.Commit()
appURL := s.Context.GetAppURL(fmt.Sprintf("auth/reset/%s", token)) appURL := ctx.GetAppURL(fmt.Sprintf("auth/reset/%s", token))
go mail.PasswordReset(u.Email, appURL) go mail.PasswordReset(u.Email, appURL)
response.WriteEmpty(w) response.WriteEmpty(w)
} }
// ResetUserPassword stores the newly chosen password for the user. // ResetPassword stores the newly chosen password for the user.
func (h *Handler) ResetUserPassword(w http.ResponseWriter, r *http.Request) { func (h *Handler) ResetPassword(w http.ResponseWriter, r *http.Request) {
method := "user.ForgotUserPassword" method := "user.ForgotUserPassword"
s := domain.NewContext(h.Runtime, r) ctx := domain.GetRequestContext(r)
token := request.Param("token") token := request.Param(r, "token")
if len(token) == 0 { if len(token) == 0 {
response.WriteMissingDataError(w, method, "missing token") response.WriteMissingDataError(w, method, "missing token")
return return
@ -586,31 +604,30 @@ func (h *Handler) ResetUserPassword(w http.ResponseWriter, r *http.Request) {
} }
newPassword := string(body) newPassword := string(body)
s.Context.Transaction, err = h.Runtime.Db.Beginx() ctx.Transaction, err = h.Runtime.Db.Beginx()
if err != nil { if err != nil {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
u, err := GetByToken(token) u, err := h.Store.User.GetByToken(ctx, token)
if err != nil { if err != nil {
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
user.Salt = secrets.GenerateSalt() u.Salt = secrets.GenerateSalt()
err = UpdateUserPassword(s, u.RefID, u.Salt, secrets.GeneratePassword(newPassword, u.Salt)) err = h.Store.User.UpdateUserPassword(ctx, u.RefID, u.Salt, secrets.GeneratePassword(newPassword, u.Salt))
if err != nil { if err != nil {
s.Context.Transaction.Rollback() ctx.Transaction.Rollback()
response.WriteServerError(w, method, err) response.WriteServerError(w, method, err)
return return
} }
eventing.Record(s, eventing.EventTypeUserPasswordReset) h.Store.Audit.Record(ctx, audit.EventTypeUserPasswordReset)
s.Context.Transaction.Commit() ctx.Transaction.Commit()
response.WriteEmpty(w) response.WriteEmpty(w)
} }
*/

View file

@ -16,12 +16,16 @@ import (
"github.com/documize/community/core/env" "github.com/documize/community/core/env"
"github.com/documize/community/domain" "github.com/documize/community/domain"
account "github.com/documize/community/domain/account/mysql" account "github.com/documize/community/domain/account/mysql"
attachment "github.com/documize/community/domain/attachment/mysql"
audit "github.com/documize/community/domain/audit/mysql" audit "github.com/documize/community/domain/audit/mysql"
doc "github.com/documize/community/domain/document/mysql"
link "github.com/documize/community/domain/link/mysql"
org "github.com/documize/community/domain/organization/mysql" org "github.com/documize/community/domain/organization/mysql"
page "github.com/documize/community/domain/page/mysql"
pin "github.com/documize/community/domain/pin/mysql" pin "github.com/documize/community/domain/pin/mysql"
setting "github.com/documize/community/domain/setting/mysql"
space "github.com/documize/community/domain/space/mysql" space "github.com/documize/community/domain/space/mysql"
user "github.com/documize/community/domain/user/mysql" user "github.com/documize/community/domain/user/mysql"
doc "github.com/documize/community/domain/document/mysql"
) )
// AttachStore selects database persistence layer // AttachStore selects database persistence layer
@ -33,4 +37,10 @@ func AttachStore(r *env.Runtime, s *domain.Store) {
s.Pin = pin.Scope{Runtime: r} s.Pin = pin.Scope{Runtime: r}
s.Audit = audit.Scope{Runtime: r} s.Audit = audit.Scope{Runtime: r}
s.Document = doc.Scope{Runtime: r} s.Document = doc.Scope{Runtime: r}
s.Setting = setting.Scope{Runtime: r}
s.Attachment = attachment.Scope{Runtime: r}
s.Link = link.Scope{Runtime: r}
s.Page = page.Scope{Runtime: r}
} }
// https://github.com/golang-sql/sqlexp/blob/c2488a8be21d20d31abf0d05c2735efd2d09afe4/quoter.go#L46

View file

@ -0,0 +1,73 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package activity
import "time"
// UserActivity represents an activity undertaken by a user.
type UserActivity struct {
ID uint64 `json:"-"`
OrgID string `json:"orgId"`
UserID string `json:"userId"`
LabelID string `json:"folderId"`
SourceID string `json:"sourceId"`
SourceName string `json:"sourceName"` // e.g. Document or Space name
SourceType SourceType `json:"sourceType"`
ActivityType Type `json:"activityType"`
Created time.Time `json:"created"`
}
// SourceType details where the activity occured.
type SourceType int
// Type determines type of user activity
type Type int
const (
// SourceTypeSpace indicates activity against a space.
SourceTypeSpace SourceType = 1
// SourceTypeDocument indicates activity against a document.
SourceTypeDocument SourceType = 2
)
const (
// TypeCreated records user document creation
TypeCreated Type = 1
// TypeRead states user has read document
TypeRead Type = 2
// TypeEdited states user has editing document
TypeEdited Type = 3
// TypeDeleted records user deleting space/document
TypeDeleted Type = 4
// TypeArchived records user archiving space/document
TypeArchived Type = 5
// TypeApproved records user approval of document
TypeApproved Type = 6
// TypeReverted records user content roll-back to previous version
TypeReverted Type = 7
// TypePublishedTemplate records user creating new document template
TypePublishedTemplate Type = 8
// TypePublishedBlock records user creating reusable content block
TypePublishedBlock Type = 9
// TypeFeedback records user providing document feedback
TypeFeedback Type = 10
)

View file

@ -0,0 +1,26 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package attachment
import "github.com/documize/community/model"
// Attachment represents an attachment to a document.
type Attachment struct {
model.BaseEntity
OrgID string `json:"orgId"`
DocumentID string `json:"documentId"`
Job string `json:"job"`
FileID string `json:"fileId"`
Filename string `json:"filename"`
Data []byte `json:"-"`
Extension string `json:"extension"`
}

View file

@ -11,15 +11,10 @@
package auth package auth
/* import "github.com/documize/community/model/user"
// Handler contains the runtime information such as logging and database.
type Handler struct {
Runtime *env.Runtime
}
// AuthenticationModel details authentication token and user details. // AuthenticationModel details authentication token and user details.
type AuthenticationModel struct { type AuthenticationModel struct {
Token string `json:"token"` Token string `json:"token"`
User user.User `json:"user"` User user.User `json:"user"`
} }
*/

52
model/auth/keycloak.go Normal file
View file

@ -0,0 +1,52 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package auth
// KeycloakAuthRequest data received via Keycloak client library
type KeycloakAuthRequest struct {
Domain string `json:"domain"`
Token string `json:"token"`
RemoteID string `json:"remoteId"`
Email string `json:"email"`
Username string `json:"username"`
Firstname string `json:"firstname"`
Lastname string `json:"lastname"`
Enabled bool `json:"enabled"`
}
// KeycloakConfig server configuration
type KeycloakConfig struct {
URL string `json:"url"`
Realm string `json:"realm"`
ClientID string `json:"clientId"`
PublicKey string `json:"publicKey"`
AdminUser string `json:"adminUser"`
AdminPassword string `json:"adminPassword"`
Group string `json:"group"`
DisableLogout bool `json:"disableLogout"`
DefaultPermissionAddSpace bool `json:"defaultPermissionAddSpace"`
}
// KeycloakAPIAuth is returned when authenticating with Keycloak REST API.
type KeycloakAPIAuth struct {
AccessToken string `json:"access_token"`
}
// KeycloakUser details user record returned by Keycloak
type KeycloakUser struct {
ID string `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
Firstname string `json:"firstName"`
Lastname string `json:"lastName"`
Enabled bool `json:"enabled"`
}

82
model/doc/doc.go Normal file
View file

@ -0,0 +1,82 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package doc
import (
"strings"
"time"
"github.com/documize/community/model"
)
// Document represents the purpose of Documize.
type Document struct {
model.BaseEntity
OrgID string `json:"orgId"`
LabelID string `json:"folderId"`
UserID string `json:"userId"`
Job string `json:"job"`
Location string `json:"location"`
Title string `json:"name"`
Excerpt string `json:"excerpt"`
Slug string `json:"-"`
Tags string `json:"tags"`
Template bool `json:"template"`
Layout string `json:"layout"`
}
// SetDefaults ensures on blanks and cleans.
func (d *Document) SetDefaults() {
d.Title = strings.TrimSpace(d.Title)
if len(d.Title) == 0 {
d.Title = "Document"
}
}
// DocumentMeta details who viewed the document.
type DocumentMeta struct {
Viewers []DocumentMetaViewer `json:"viewers"`
Editors []DocumentMetaEditor `json:"editors"`
}
// DocumentMetaViewer contains the "view" metatdata content.
type DocumentMetaViewer struct {
UserID string `json:"userId"`
Created time.Time `json:"created"`
Firstname string `json:"firstname"`
Lastname string `json:"lastname"`
}
// DocumentMetaEditor contains the "edit" metatdata content.
type DocumentMetaEditor struct {
PageID string `json:"pageId"`
UserID string `json:"userId"`
Action string `json:"action"`
Created time.Time `json:"created"`
Firstname string `json:"firstname"`
Lastname string `json:"lastname"`
}
// UploadModel details the job ID of an uploaded document.
type UploadModel struct {
JobID string `json:"jobId"`
}
// SitemapDocument details a document that can be exposed via Sitemap.
type SitemapDocument struct {
DocumentID string
Document string
FolderID string
Folder string
Revised time.Time
}

39
model/link/link.go Normal file
View file

@ -0,0 +1,39 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package link
import "github.com/documize/community/model"
// Link defines a reference between a section and another document/section/attachment.
type Link struct {
model.BaseEntity
OrgID string `json:"orgId"`
FolderID string `json:"folderId"`
UserID string `json:"userId"`
LinkType string `json:"linkType"`
SourceDocumentID string `json:"sourceDocumentId"`
SourcePageID string `json:"sourcePageId"`
TargetDocumentID string `json:"targetDocumentId"`
TargetID string `json:"targetId"`
Orphan bool `json:"orphan"`
}
// Candidate defines a potential link to a document/section/attachment.
type Candidate struct {
RefID string `json:"id"`
LinkType string `json:"linkType"`
FolderID string `json:"folderId"`
DocumentID string `json:"documentId"`
TargetID string `json:"targetId"`
Title string `json:"title"` // what we label the link
Context string `json:"context"` // additional context (e.g. excerpt, parent, file extension)
}

38
model/org/meta.go Normal file
View file

@ -0,0 +1,38 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package org
import "time"
// SitemapDocument details a document that can be exposed via Sitemap.
type SitemapDocument struct {
DocumentID string
Document string
FolderID string
Folder string
Revised time.Time
}
// SiteMeta holds information associated with an Organization.
type SiteMeta struct {
OrgID string `json:"orgId"`
Title string `json:"title"`
Message string `json:"message"`
URL string `json:"url"`
AllowAnonymousAccess bool `json:"allowAnonymousAccess"`
AuthProvider string `json:"authProvider"`
AuthConfig string `json:"authConfig"`
Version string `json:"version"`
Edition string `json:"edition"`
Valid bool `json:"valid"`
ConversionEndpoint string `json:"conversionEndpoint"`
}

116
model/page/page.go Normal file
View file

@ -0,0 +1,116 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package page
import (
"strings"
"time"
"github.com/documize/community/model"
)
// Page represents a section within a document.
type Page struct {
model.BaseEntity
OrgID string `json:"orgId"`
DocumentID string `json:"documentId"`
UserID string `json:"userId"`
ContentType string `json:"contentType"`
PageType string `json:"pageType"`
BlockID string `json:"blockId"`
Level uint64 `json:"level"`
Sequence float64 `json:"sequence"`
Title string `json:"title"`
Body string `json:"body"`
Revisions uint64 `json:"revisions"`
}
// SetDefaults ensures no blank values.
func (p *Page) SetDefaults() {
if len(p.ContentType) == 0 {
p.ContentType = "wysiwyg"
}
p.Title = strings.TrimSpace(p.Title)
}
// IsSectionType tells us that page is "words"
func (p *Page) IsSectionType() bool {
return p.PageType == "section"
}
// IsTabType tells us that page is "SaaS data embed"
func (p *Page) IsTabType() bool {
return p.PageType == "tab"
}
// Meta holds raw page data that is used to
// render the actual page data.
type Meta struct {
ID uint64 `json:"id"`
Created time.Time `json:"created"`
Revised time.Time `json:"revised"`
OrgID string `json:"orgId"`
UserID string `json:"userId"`
DocumentID string `json:"documentId"`
PageID string `json:"pageId"`
RawBody string `json:"rawBody"` // a blob of data
Config string `json:"config"` // JSON based custom config for this type
ExternalSource bool `json:"externalSource"` // true indicates data sourced externally
}
// SetDefaults ensures no blank values.
func (p *Meta) SetDefaults() {
if len(p.Config) == 0 {
p.Config = "{}"
}
}
// Revision holds the previous version of a Page.
type Revision struct {
model.BaseEntity
OrgID string `json:"orgId"`
DocumentID string `json:"documentId"`
PageID string `json:"pageId"`
OwnerID string `json:"ownerId"`
UserID string `json:"userId"`
ContentType string `json:"contentType"`
PageType string `json:"pageType"`
Title string `json:"title"`
Body string `json:"body"`
RawBody string `json:"rawBody"`
Config string `json:"config"`
Email string `json:"email"`
Firstname string `json:"firstname"`
Lastname string `json:"lastname"`
Initials string `json:"initials"`
Revisions int `json:"revisions"`
}
// Block represents a section that has been published as a reusable content block.
type Block struct {
model.BaseEntity
OrgID string `json:"orgId"`
LabelID string `json:"folderId"`
UserID string `json:"userId"`
ContentType string `json:"contentType"`
PageType string `json:"pageType"`
Title string `json:"title"`
Body string `json:"body"`
Excerpt string `json:"excerpt"`
RawBody string `json:"rawBody"` // a blob of data
Config string `json:"config"` // JSON based custom config for this type
ExternalSource bool `json:"externalSource"` // true indicates data sourced externally
Used uint64 `json:"used"`
Firstname string `json:"firstname"`
Lastname string `json:"lastname"`
}

43
model/search/search.go Normal file
View file

@ -0,0 +1,43 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package search
import "time"
// Search holds raw search results.
type Search struct {
ID string `json:"id"`
Created time.Time `json:"created"`
Revised time.Time `json:"revised"`
OrgID string
DocumentID string
Level uint64
Sequence float64
DocumentTitle string
Slug string
PageTitle string
Body string
}
// DocumentSearch represents 'presentable' search results.
type DocumentSearch struct {
ID string `json:"id"`
DocumentID string `json:"documentId"`
DocumentTitle string `json:"documentTitle"`
DocumentSlug string `json:"documentSlug"`
DocumentExcerpt string `json:"documentExcerpt"`
Tags string `json:"documentTags"`
PageTitle string `json:"pageTitle"`
LabelID string `json:"folderId"`
LabelName string `json:"folderName"`
FolderSlug string `json:"folderSlug"`
}

View file

@ -0,0 +1,54 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
package template
import "time"
// Template is used to create a new document.
// Template can consist of content, attachments and
// have associated meta data indentifying author, version
// contact details and more.
type Template struct {
ID string `json:"id"`
Title string `json:"title"`
Description string `json:"description"`
Author string `json:"author"`
Type Type `json:"type"`
Dated time.Time `json:"dated"`
}
// Type determines who can see a template.
type Type int
const (
// TypePublic means anyone can see the template.
TypePublic Type = 1
// TypePrivate means only the owner can see the template.
TypePrivate Type = 2
// TypeRestricted means selected users can see the template.
TypeRestricted Type = 3
)
// IsPublic means anyone can see the template.
func (t *Template) IsPublic() bool {
return t.Type == TypePublic
}
// IsPrivate means only the owner can see the template.
func (t *Template) IsPrivate() bool {
return t.Type == TypePrivate
}
// IsRestricted means selected users can see the template.
func (t *Template) IsRestricted() bool {
return t.Type == TypeRestricted
}

View file

@ -14,31 +14,48 @@ package routing
import ( import (
"net/http" "net/http"
"github.com/documize/community/core/api"
"github.com/documize/community/core/api/endpoint" "github.com/documize/community/core/api/endpoint"
"github.com/documize/community/core/env" "github.com/documize/community/core/env"
"github.com/documize/community/domain" "github.com/documize/community/domain"
"github.com/documize/community/domain/attachment"
"github.com/documize/community/domain/auth"
"github.com/documize/community/domain/link"
"github.com/documize/community/domain/meta"
"github.com/documize/community/domain/organization" "github.com/documize/community/domain/organization"
"github.com/documize/community/domain/pin" "github.com/documize/community/domain/pin"
"github.com/documize/community/domain/setting"
"github.com/documize/community/domain/space" "github.com/documize/community/domain/space"
"github.com/documize/community/domain/user"
"github.com/documize/community/server/web" "github.com/documize/community/server/web"
) )
// RegisterEndpoints register routes for serving API endpoints // RegisterEndpoints register routes for serving API endpoints
func RegisterEndpoints(rt *env.Runtime, s *domain.Store) { func RegisterEndpoints(rt *env.Runtime, s *domain.Store) {
// We pass server/application level contextual requirements into HTTP handlers
// DO NOT pass in per request context (that is done by auth middleware per request)
pin := pin.Handler{Runtime: rt, Store: s}
auth := auth.Handler{Runtime: rt, Store: s}
meta := meta.Handler{Runtime: rt, Store: s}
user := user.Handler{Runtime: rt, Store: s}
link := link.Handler{Runtime: rt, Store: s}
space := space.Handler{Runtime: rt, Store: s}
setting := setting.Handler{Runtime: rt, Store: s}
attachment := attachment.Handler{Runtime: rt, Store: s}
organization := organization.Handler{Runtime: rt, Store: s}
//************************************************** //**************************************************
// Non-secure routes // Non-secure routes
//************************************************** //**************************************************
Add(rt, RoutePrefixPublic, "meta", []string{"GET", "OPTIONS"}, nil, endpoint.GetMeta) Add(rt, RoutePrefixPublic, "meta", []string{"GET", "OPTIONS"}, nil, meta.Meta)
Add(rt, RoutePrefixPublic, "authenticate/keycloak", []string{"POST", "OPTIONS"}, nil, endpoint.AuthenticateKeycloak) Add(rt, RoutePrefixPublic, "authenticate/keycloak", []string{"POST", "OPTIONS"}, nil, endpoint.AuthenticateKeycloak)
Add(rt, RoutePrefixPublic, "authenticate", []string{"POST", "OPTIONS"}, nil, endpoint.Authenticate) Add(rt, RoutePrefixPublic, "authenticate", []string{"POST", "OPTIONS"}, nil, auth.Login)
Add(rt, RoutePrefixPublic, "validate", []string{"GET", "OPTIONS"}, nil, endpoint.ValidateAuthToken) Add(rt, RoutePrefixPublic, "validate", []string{"GET", "OPTIONS"}, nil, auth.ValidateToken)
Add(rt, RoutePrefixPublic, "forgot", []string{"POST", "OPTIONS"}, nil, endpoint.ForgotUserPassword) Add(rt, RoutePrefixPublic, "forgot", []string{"POST", "OPTIONS"}, nil, user.ForgotPassword)
Add(rt, RoutePrefixPublic, "reset/{token}", []string{"POST", "OPTIONS"}, nil, endpoint.ResetUserPassword) Add(rt, RoutePrefixPublic, "reset/{token}", []string{"POST", "OPTIONS"}, nil, user.ResetPassword)
Add(rt, RoutePrefixPublic, "share/{folderID}", []string{"POST", "OPTIONS"}, nil, endpoint.AcceptSharedFolder) Add(rt, RoutePrefixPublic, "share/{folderID}", []string{"POST", "OPTIONS"}, nil, space.AcceptInvitation)
Add(rt, RoutePrefixPublic, "attachments/{orgID}/{attachmentID}", []string{"GET", "OPTIONS"}, nil, endpoint.AttachmentDownload) Add(rt, RoutePrefixPublic, "attachments/{orgID}/{attachmentID}", []string{"GET", "OPTIONS"}, nil, attachment.Download)
Add(rt, RoutePrefixPublic, "version", []string{"GET", "OPTIONS"}, nil, func(w http.ResponseWriter, r *http.Request) { Add(rt, RoutePrefixPublic, "version", []string{"GET", "OPTIONS"}, nil, func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(api.Runtime.Product.Version)) w.Write([]byte(rt.Product.Version))
}) })
//************************************************** //**************************************************
@ -48,7 +65,6 @@ func RegisterEndpoints(rt *env.Runtime, s *domain.Store) {
// Import & Convert Document // Import & Convert Document
Add(rt, RoutePrefixPrivate, "import/folder/{folderID}", []string{"POST", "OPTIONS"}, nil, endpoint.UploadConvertDocument) Add(rt, RoutePrefixPrivate, "import/folder/{folderID}", []string{"POST", "OPTIONS"}, nil, endpoint.UploadConvertDocument)
// Document
Add(rt, RoutePrefixPrivate, "documents/{documentID}/export", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentAsDocx) Add(rt, RoutePrefixPrivate, "documents/{documentID}/export", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentAsDocx)
Add(rt, RoutePrefixPrivate, "documents", []string{"GET", "OPTIONS"}, []string{"filter", "tag"}, endpoint.GetDocumentsByTag) Add(rt, RoutePrefixPrivate, "documents", []string{"GET", "OPTIONS"}, []string{"filter", "tag"}, endpoint.GetDocumentsByTag)
Add(rt, RoutePrefixPrivate, "documents", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentsByFolder) Add(rt, RoutePrefixPrivate, "documents", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentsByFolder)
@ -57,7 +73,6 @@ func RegisterEndpoints(rt *env.Runtime, s *domain.Store) {
Add(rt, RoutePrefixPrivate, "documents/{documentID}", []string{"DELETE", "OPTIONS"}, nil, endpoint.DeleteDocument) Add(rt, RoutePrefixPrivate, "documents/{documentID}", []string{"DELETE", "OPTIONS"}, nil, endpoint.DeleteDocument)
Add(rt, RoutePrefixPrivate, "documents/{documentID}/activity", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentActivity) Add(rt, RoutePrefixPrivate, "documents/{documentID}/activity", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentActivity)
// Document Page
Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/level", []string{"POST", "OPTIONS"}, nil, endpoint.ChangeDocumentPageLevel) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/level", []string{"POST", "OPTIONS"}, nil, endpoint.ChangeDocumentPageLevel)
Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/sequence", []string{"POST", "OPTIONS"}, nil, endpoint.ChangeDocumentPageSequence) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/sequence", []string{"POST", "OPTIONS"}, nil, endpoint.ChangeDocumentPageSequence)
Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/batch", []string{"POST", "OPTIONS"}, nil, endpoint.GetDocumentPagesBatch) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/batch", []string{"POST", "OPTIONS"}, nil, endpoint.GetDocumentPagesBatch)
@ -72,19 +87,15 @@ func RegisterEndpoints(rt *env.Runtime, s *domain.Store) {
Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages", []string{"DELETE", "OPTIONS"}, nil, endpoint.DeleteDocumentPages) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages", []string{"DELETE", "OPTIONS"}, nil, endpoint.DeleteDocumentPages)
Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/{pageID}", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentPage) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/{pageID}", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentPage)
Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages", []string{"POST", "OPTIONS"}, nil, endpoint.AddDocumentPage) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages", []string{"POST", "OPTIONS"}, nil, endpoint.AddDocumentPage)
Add(rt, RoutePrefixPrivate, "documents/{documentID}/attachments", []string{"GET", "OPTIONS"}, nil, endpoint.GetAttachments) Add(rt, RoutePrefixPrivate, "documents/{documentID}/attachments", []string{"GET", "OPTIONS"}, nil, attachment.Get)
Add(rt, RoutePrefixPrivate, "documents/{documentID}/attachments/{attachmentID}", []string{"DELETE", "OPTIONS"}, nil, endpoint.DeleteAttachment) Add(rt, RoutePrefixPrivate, "documents/{documentID}/attachments/{attachmentID}", []string{"DELETE", "OPTIONS"}, nil, attachment.Delete)
Add(rt, RoutePrefixPrivate, "documents/{documentID}/attachments", []string{"POST", "OPTIONS"}, nil, endpoint.AddAttachments) Add(rt, RoutePrefixPrivate, "documents/{documentID}/attachments", []string{"POST", "OPTIONS"}, nil, attachment.Add)
Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/{pageID}/meta", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentPageMeta) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/{pageID}/meta", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentPageMeta)
Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/{pageID}/copy/{targetID}", []string{"POST", "OPTIONS"}, nil, endpoint.CopyPage) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/{pageID}/copy/{targetID}", []string{"POST", "OPTIONS"}, nil, endpoint.CopyPage)
// Organization
organization := organization.Handler{Runtime: rt, Store: s}
Add(rt, RoutePrefixPrivate, "organizations/{orgID}", []string{"GET", "OPTIONS"}, nil, organization.Get) Add(rt, RoutePrefixPrivate, "organizations/{orgID}", []string{"GET", "OPTIONS"}, nil, organization.Get)
Add(rt, RoutePrefixPrivate, "organizations/{orgID}", []string{"PUT", "OPTIONS"}, nil, organization.Update) Add(rt, RoutePrefixPrivate, "organizations/{orgID}", []string{"PUT", "OPTIONS"}, nil, organization.Update)
// Space
space := space.Handler{Runtime: rt, Store: s}
Add(rt, RoutePrefixPrivate, "folders/{folderID}", []string{"DELETE", "OPTIONS"}, nil, space.Delete) Add(rt, RoutePrefixPrivate, "folders/{folderID}", []string{"DELETE", "OPTIONS"}, nil, space.Delete)
Add(rt, RoutePrefixPrivate, "folders/{folderID}/move/{moveToId}", []string{"DELETE", "OPTIONS"}, nil, space.Remove) Add(rt, RoutePrefixPrivate, "folders/{folderID}/move/{moveToId}", []string{"DELETE", "OPTIONS"}, nil, space.Remove)
Add(rt, RoutePrefixPrivate, "folders/{folderID}/permissions", []string{"PUT", "OPTIONS"}, nil, space.SetPermissions) Add(rt, RoutePrefixPrivate, "folders/{folderID}/permissions", []string{"PUT", "OPTIONS"}, nil, space.SetPermissions)
@ -96,28 +107,24 @@ func RegisterEndpoints(rt *env.Runtime, s *domain.Store) {
Add(rt, RoutePrefixPrivate, "folders/{folderID}", []string{"GET", "OPTIONS"}, nil, space.Get) Add(rt, RoutePrefixPrivate, "folders/{folderID}", []string{"GET", "OPTIONS"}, nil, space.Get)
Add(rt, RoutePrefixPrivate, "folders/{folderID}", []string{"PUT", "OPTIONS"}, nil, space.Update) Add(rt, RoutePrefixPrivate, "folders/{folderID}", []string{"PUT", "OPTIONS"}, nil, space.Update)
// Users Add(rt, RoutePrefixPrivate, "users/{userID}/password", []string{"POST", "OPTIONS"}, nil, user.ChangePassword)
Add(rt, RoutePrefixPrivate, "users/{userID}/password", []string{"POST", "OPTIONS"}, nil, endpoint.ChangeUserPassword) Add(rt, RoutePrefixPrivate, "users/{userID}/permissions", []string{"GET", "OPTIONS"}, nil, user.UserSpacePermissions)
Add(rt, RoutePrefixPrivate, "users/{userID}/permissions", []string{"GET", "OPTIONS"}, nil, endpoint.GetUserFolderPermissions) Add(rt, RoutePrefixPrivate, "users", []string{"POST", "OPTIONS"}, nil, user.Add)
Add(rt, RoutePrefixPrivate, "users", []string{"POST", "OPTIONS"}, nil, endpoint.AddUser) Add(rt, RoutePrefixPrivate, "users/folder/{folderID}", []string{"GET", "OPTIONS"}, nil, user.GetSpaceUsers)
Add(rt, RoutePrefixPrivate, "users/folder/{folderID}", []string{"GET", "OPTIONS"}, nil, endpoint.GetFolderUsers) Add(rt, RoutePrefixPrivate, "users", []string{"GET", "OPTIONS"}, nil, user.GetOrganizationUsers)
Add(rt, RoutePrefixPrivate, "users", []string{"GET", "OPTIONS"}, nil, endpoint.GetOrganizationUsers) Add(rt, RoutePrefixPrivate, "users/{userID}", []string{"GET", "OPTIONS"}, nil, user.Get)
Add(rt, RoutePrefixPrivate, "users/{userID}", []string{"GET", "OPTIONS"}, nil, endpoint.GetUser) Add(rt, RoutePrefixPrivate, "users/{userID}", []string{"PUT", "OPTIONS"}, nil, user.Update)
Add(rt, RoutePrefixPrivate, "users/{userID}", []string{"PUT", "OPTIONS"}, nil, endpoint.UpdateUser) Add(rt, RoutePrefixPrivate, "users/{userID}", []string{"DELETE", "OPTIONS"}, nil, user.Delete)
Add(rt, RoutePrefixPrivate, "users/{userID}", []string{"DELETE", "OPTIONS"}, nil, endpoint.DeleteUser)
Add(rt, RoutePrefixPrivate, "users/sync", []string{"GET", "OPTIONS"}, nil, endpoint.SyncKeycloak) Add(rt, RoutePrefixPrivate, "users/sync", []string{"GET", "OPTIONS"}, nil, endpoint.SyncKeycloak)
// Search
Add(rt, RoutePrefixPrivate, "search", []string{"GET", "OPTIONS"}, nil, endpoint.SearchDocuments) Add(rt, RoutePrefixPrivate, "search", []string{"GET", "OPTIONS"}, nil, endpoint.SearchDocuments)
// Templates
Add(rt, RoutePrefixPrivate, "templates", []string{"POST", "OPTIONS"}, nil, endpoint.SaveAsTemplate) Add(rt, RoutePrefixPrivate, "templates", []string{"POST", "OPTIONS"}, nil, endpoint.SaveAsTemplate)
Add(rt, RoutePrefixPrivate, "templates", []string{"GET", "OPTIONS"}, nil, endpoint.GetSavedTemplates) Add(rt, RoutePrefixPrivate, "templates", []string{"GET", "OPTIONS"}, nil, endpoint.GetSavedTemplates)
Add(rt, RoutePrefixPrivate, "templates/stock", []string{"GET", "OPTIONS"}, nil, endpoint.GetStockTemplates) Add(rt, RoutePrefixPrivate, "templates/stock", []string{"GET", "OPTIONS"}, nil, endpoint.GetStockTemplates)
Add(rt, RoutePrefixPrivate, "templates/{templateID}/folder/{folderID}", []string{"POST", "OPTIONS"}, []string{"type", "stock"}, endpoint.StartDocumentFromStockTemplate) Add(rt, RoutePrefixPrivate, "templates/{templateID}/folder/{folderID}", []string{"POST", "OPTIONS"}, []string{"type", "stock"}, endpoint.StartDocumentFromStockTemplate)
Add(rt, RoutePrefixPrivate, "templates/{templateID}/folder/{folderID}", []string{"POST", "OPTIONS"}, []string{"type", "saved"}, endpoint.StartDocumentFromSavedTemplate) Add(rt, RoutePrefixPrivate, "templates/{templateID}/folder/{folderID}", []string{"POST", "OPTIONS"}, []string{"type", "saved"}, endpoint.StartDocumentFromSavedTemplate)
// Sections
Add(rt, RoutePrefixPrivate, "sections", []string{"GET", "OPTIONS"}, nil, endpoint.GetSections) Add(rt, RoutePrefixPrivate, "sections", []string{"GET", "OPTIONS"}, nil, endpoint.GetSections)
Add(rt, RoutePrefixPrivate, "sections", []string{"POST", "OPTIONS"}, nil, endpoint.RunSectionCommand) Add(rt, RoutePrefixPrivate, "sections", []string{"POST", "OPTIONS"}, nil, endpoint.RunSectionCommand)
Add(rt, RoutePrefixPrivate, "sections/refresh", []string{"GET", "OPTIONS"}, nil, endpoint.RefreshSections) Add(rt, RoutePrefixPrivate, "sections/refresh", []string{"GET", "OPTIONS"}, nil, endpoint.RefreshSections)
@ -128,28 +135,23 @@ func RegisterEndpoints(rt *env.Runtime, s *domain.Store) {
Add(rt, RoutePrefixPrivate, "sections/blocks", []string{"POST", "OPTIONS"}, nil, endpoint.AddBlock) Add(rt, RoutePrefixPrivate, "sections/blocks", []string{"POST", "OPTIONS"}, nil, endpoint.AddBlock)
Add(rt, RoutePrefixPrivate, "sections/targets", []string{"GET", "OPTIONS"}, nil, endpoint.GetPageMoveCopyTargets) Add(rt, RoutePrefixPrivate, "sections/targets", []string{"GET", "OPTIONS"}, nil, endpoint.GetPageMoveCopyTargets)
// Links Add(rt, RoutePrefixPrivate, "links/{folderID}/{documentID}/{pageID}", []string{"GET", "OPTIONS"}, nil, link.GetLinkCandidates)
Add(rt, RoutePrefixPrivate, "links/{folderID}/{documentID}/{pageID}", []string{"GET", "OPTIONS"}, nil, endpoint.GetLinkCandidates) Add(rt, RoutePrefixPrivate, "links", []string{"GET", "OPTIONS"}, nil, link.SearchLinkCandidates)
Add(rt, RoutePrefixPrivate, "links", []string{"GET", "OPTIONS"}, nil, endpoint.SearchLinkCandidates)
Add(rt, RoutePrefixPrivate, "documents/{documentID}/links", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentLinks) Add(rt, RoutePrefixPrivate, "documents/{documentID}/links", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentLinks)
// Global installation-wide config Add(rt, RoutePrefixPrivate, "global/smtp", []string{"GET", "OPTIONS"}, nil, setting.SMTP)
Add(rt, RoutePrefixPrivate, "global/smtp", []string{"GET", "OPTIONS"}, nil, endpoint.GetSMTPConfig) Add(rt, RoutePrefixPrivate, "global/smtp", []string{"PUT", "OPTIONS"}, nil, setting.SetSMTP)
Add(rt, RoutePrefixPrivate, "global/smtp", []string{"PUT", "OPTIONS"}, nil, endpoint.SaveSMTPConfig) Add(rt, RoutePrefixPrivate, "global/license", []string{"GET", "OPTIONS"}, nil, setting.License)
Add(rt, RoutePrefixPrivate, "global/license", []string{"GET", "OPTIONS"}, nil, endpoint.GetLicense) Add(rt, RoutePrefixPrivate, "global/license", []string{"PUT", "OPTIONS"}, nil, setting.SetLicense)
Add(rt, RoutePrefixPrivate, "global/license", []string{"PUT", "OPTIONS"}, nil, endpoint.SaveLicense) Add(rt, RoutePrefixPrivate, "global/auth", []string{"GET", "OPTIONS"}, nil, setting.AuthConfig)
Add(rt, RoutePrefixPrivate, "global/auth", []string{"GET", "OPTIONS"}, nil, endpoint.GetAuthConfig) Add(rt, RoutePrefixPrivate, "global/auth", []string{"PUT", "OPTIONS"}, nil, setting.SetAuthConfig)
Add(rt, RoutePrefixPrivate, "global/auth", []string{"PUT", "OPTIONS"}, nil, endpoint.SaveAuthConfig)
// Pinned items
pin := pin.Handler{Runtime: rt, Store: s}
Add(rt, RoutePrefixPrivate, "pin/{userID}", []string{"POST", "OPTIONS"}, nil, pin.Add) Add(rt, RoutePrefixPrivate, "pin/{userID}", []string{"POST", "OPTIONS"}, nil, pin.Add)
Add(rt, RoutePrefixPrivate, "pin/{userID}", []string{"GET", "OPTIONS"}, nil, pin.GetUserPins) Add(rt, RoutePrefixPrivate, "pin/{userID}", []string{"GET", "OPTIONS"}, nil, pin.GetUserPins)
Add(rt, RoutePrefixPrivate, "pin/{userID}/sequence", []string{"POST", "OPTIONS"}, nil, pin.UpdatePinSequence) Add(rt, RoutePrefixPrivate, "pin/{userID}/sequence", []string{"POST", "OPTIONS"}, nil, pin.UpdatePinSequence)
Add(rt, RoutePrefixPrivate, "pin/{userID}/{pinID}", []string{"DELETE", "OPTIONS"}, nil, pin.DeleteUserPin) Add(rt, RoutePrefixPrivate, "pin/{userID}/{pinID}", []string{"DELETE", "OPTIONS"}, nil, pin.DeleteUserPin)
// Single page app handler Add(rt, RoutePrefixRoot, "robots.txt", []string{"GET", "OPTIONS"}, nil, meta.RobotsTxt)
Add(rt, RoutePrefixRoot, "robots.txt", []string{"GET", "OPTIONS"}, nil, endpoint.GetRobots) Add(rt, RoutePrefixRoot, "sitemap.xml", []string{"GET", "OPTIONS"}, nil, meta.Sitemap)
Add(rt, RoutePrefixRoot, "sitemap.xml", []string{"GET", "OPTIONS"}, nil, endpoint.GetSitemap)
Add(rt, RoutePrefixRoot, "{rest:.*}", nil, nil, web.EmberHandler) Add(rt, RoutePrefixRoot, "{rest:.*}", nil, nil, web.EmberHandler)
} }

View file

@ -1,4 +1,4 @@
#sqlx # sqlx
[![Build Status](https://drone.io/github.com/jmoiron/sqlx/status.png)](https://drone.io/github.com/jmoiron/sqlx/latest) [![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/jmoiron/sqlx) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/jmoiron/sqlx/master/LICENSE) [![Build Status](https://drone.io/github.com/jmoiron/sqlx/status.png)](https://drone.io/github.com/jmoiron/sqlx/latest) [![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/jmoiron/sqlx) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/jmoiron/sqlx/master/LICENSE)
@ -26,9 +26,7 @@ This breaks backwards compatibility, but it's in a way that is trivially fixable
(`s/JsonText/JSONText/g`). The `types` package is both experimental and not in (`s/JsonText/JSONText/g`). The `types` package is both experimental and not in
active development currently. active development currently.
More importantly, [golang bug #13905](https://github.com/golang/go/issues/13905) * Using Go 1.6 and below with `types.JSONText` and `types.GzippedText` can be _potentially unsafe_, **especially** when used with common auto-scan sqlx idioms like `Select` and `Get`. See [golang bug #13905](https://github.com/golang/go/issues/13905).
makes `types.JSONText` and `types.GzippedText` _potentially unsafe_, **especially**
when used with common auto-scan sqlx idioms like `Select` and `Get`.
### Backwards Compatibility ### Backwards Compatibility

View file

@ -27,7 +27,7 @@ func BindType(driverName string) int {
return QUESTION return QUESTION
case "sqlite3": case "sqlite3":
return QUESTION return QUESTION
case "oci8": case "oci8", "ora", "goracle":
return NAMED return NAMED
} }
return UNKNOWN return UNKNOWN
@ -43,27 +43,28 @@ func Rebind(bindType int, query string) string {
return query return query
} }
qb := []byte(query)
// Add space enough for 10 params before we have to allocate // Add space enough for 10 params before we have to allocate
rqb := make([]byte, 0, len(qb)+10) rqb := make([]byte, 0, len(query)+10)
j := 1
for _, b := range qb { var i, j int
if b == '?' {
switch bindType { for i = strings.Index(query, "?"); i != -1; i = strings.Index(query, "?") {
case DOLLAR: rqb = append(rqb, query[:i]...)
rqb = append(rqb, '$')
case NAMED: switch bindType {
rqb = append(rqb, ':', 'a', 'r', 'g') case DOLLAR:
} rqb = append(rqb, '$')
for _, b := range strconv.Itoa(j) { case NAMED:
rqb = append(rqb, byte(b)) rqb = append(rqb, ':', 'a', 'r', 'g')
}
j++
} else {
rqb = append(rqb, b)
} }
j++
rqb = strconv.AppendInt(rqb, int64(j), 10)
query = query[i+1:]
} }
return string(rqb)
return string(append(rqb, query...))
} }
// Experimental implementation of Rebind which uses a bytes.Buffer. The code is // Experimental implementation of Rebind which uses a bytes.Buffer. The code is
@ -135,9 +136,9 @@ func In(query string, args ...interface{}) (string, []interface{}, error) {
} }
newArgs := make([]interface{}, 0, flatArgsCount) newArgs := make([]interface{}, 0, flatArgsCount)
buf := bytes.NewBuffer(make([]byte, 0, len(query)+len(", ?")*flatArgsCount))
var arg, offset int var arg, offset int
var buf bytes.Buffer
for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') { for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') {
if arg >= len(meta) { if arg >= len(meta) {
@ -163,13 +164,12 @@ func In(query string, args ...interface{}) (string, []interface{}, error) {
// write everything up to and including our ? character // write everything up to and including our ? character
buf.WriteString(query[:offset+i+1]) buf.WriteString(query[:offset+i+1])
newArgs = append(newArgs, argMeta.v.Index(0).Interface())
for si := 1; si < argMeta.length; si++ { for si := 1; si < argMeta.length; si++ {
buf.WriteString(", ?") buf.WriteString(", ?")
newArgs = append(newArgs, argMeta.v.Index(si).Interface())
} }
newArgs = appendReflectSlice(newArgs, argMeta.v, argMeta.length)
// slice the query and reset the offset. this avoids some bookkeeping for // slice the query and reset the offset. this avoids some bookkeeping for
// the write after the loop // the write after the loop
query = query[offset+i+1:] query = query[offset+i+1:]
@ -184,3 +184,24 @@ func In(query string, args ...interface{}) (string, []interface{}, error) {
return buf.String(), newArgs, nil return buf.String(), newArgs, nil
} }
func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} {
switch val := v.Interface().(type) {
case []interface{}:
args = append(args, val...)
case []int:
for i := range val {
args = append(args, val[i])
}
case []string:
for i := range val {
args = append(args, val[i])
}
default:
for si := 0; si < vlen; si++ {
args = append(args, v.Index(si).Interface())
}
}
return args
}

View file

@ -36,6 +36,7 @@ func (n *NamedStmt) Close() error {
} }
// Exec executes a named statement using the struct passed. // Exec executes a named statement using the struct passed.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) { func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil { if err != nil {
@ -45,6 +46,7 @@ func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) {
} }
// Query executes a named statement using the struct argument, returning rows. // Query executes a named statement using the struct argument, returning rows.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) { func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil { if err != nil {
@ -56,6 +58,7 @@ func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) {
// QueryRow executes a named statement against the database. Because sqlx cannot // QueryRow executes a named statement against the database. Because sqlx cannot
// create a *sql.Row with an error condition pre-set for binding errors, sqlx // create a *sql.Row with an error condition pre-set for binding errors, sqlx
// returns a *sqlx.Row instead. // returns a *sqlx.Row instead.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) QueryRow(arg interface{}) *Row { func (n *NamedStmt) QueryRow(arg interface{}) *Row {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil { if err != nil {
@ -65,6 +68,7 @@ func (n *NamedStmt) QueryRow(arg interface{}) *Row {
} }
// MustExec execs a NamedStmt, panicing on error // MustExec execs a NamedStmt, panicing on error
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) MustExec(arg interface{}) sql.Result { func (n *NamedStmt) MustExec(arg interface{}) sql.Result {
res, err := n.Exec(arg) res, err := n.Exec(arg)
if err != nil { if err != nil {
@ -74,6 +78,7 @@ func (n *NamedStmt) MustExec(arg interface{}) sql.Result {
} }
// Queryx using this NamedStmt // Queryx using this NamedStmt
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) { func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) {
r, err := n.Query(arg) r, err := n.Query(arg)
if err != nil { if err != nil {
@ -84,11 +89,13 @@ func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) {
// QueryRowx this NamedStmt. Because of limitations with QueryRow, this is // QueryRowx this NamedStmt. Because of limitations with QueryRow, this is
// an alias for QueryRow. // an alias for QueryRow.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) QueryRowx(arg interface{}) *Row { func (n *NamedStmt) QueryRowx(arg interface{}) *Row {
return n.QueryRow(arg) return n.QueryRow(arg)
} }
// Select using this NamedStmt // Select using this NamedStmt
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) Select(dest interface{}, arg interface{}) error { func (n *NamedStmt) Select(dest interface{}, arg interface{}) error {
rows, err := n.Queryx(arg) rows, err := n.Queryx(arg)
if err != nil { if err != nil {
@ -100,6 +107,7 @@ func (n *NamedStmt) Select(dest interface{}, arg interface{}) error {
} }
// Get using this NamedStmt // Get using this NamedStmt
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) Get(dest interface{}, arg interface{}) error { func (n *NamedStmt) Get(dest interface{}, arg interface{}) error {
r := n.QueryRowx(arg) r := n.QueryRowx(arg)
return r.scanAny(dest, false) return r.scanAny(dest, false)
@ -250,7 +258,7 @@ func compileNamedQuery(qs []byte, bindType int) (query string, names []string, e
inName = true inName = true
name = []byte{} name = []byte{}
// if we're in a name, and this is an allowed character, continue // if we're in a name, and this is an allowed character, continue
} else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_') && i != last { } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last {
// append the byte to the name if we are in a name and not on the last byte // append the byte to the name if we are in a name and not on the last byte
name = append(name, b) name = append(name, b)
// if we're in a name and it's not an allowed character, the name is done // if we're in a name and it's not an allowed character, the name is done

132
vendor/github.com/jmoiron/sqlx/named_context.go generated vendored Normal file
View file

@ -0,0 +1,132 @@
// +build go1.8
package sqlx
import (
"context"
"database/sql"
)
// A union interface of contextPreparer and binder, required to be able to
// prepare named statements with context (as the bindtype must be determined).
type namedPreparerContext interface {
PreparerContext
binder
}
func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) {
bindType := BindType(p.DriverName())
q, args, err := compileNamedQuery([]byte(query), bindType)
if err != nil {
return nil, err
}
stmt, err := PreparexContext(ctx, p, q)
if err != nil {
return nil, err
}
return &NamedStmt{
QueryString: q,
Params: args,
Stmt: stmt,
}, nil
}
// ExecContext executes a named statement using the struct passed.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) ExecContext(ctx context.Context, arg interface{}) (sql.Result, error) {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return *new(sql.Result), err
}
return n.Stmt.ExecContext(ctx, args...)
}
// QueryContext executes a named statement using the struct argument, returning rows.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) QueryContext(ctx context.Context, arg interface{}) (*sql.Rows, error) {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return nil, err
}
return n.Stmt.QueryContext(ctx, args...)
}
// QueryRowContext executes a named statement against the database. Because sqlx cannot
// create a *sql.Row with an error condition pre-set for binding errors, sqlx
// returns a *sqlx.Row instead.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) QueryRowContext(ctx context.Context, arg interface{}) *Row {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return &Row{err: err}
}
return n.Stmt.QueryRowxContext(ctx, args...)
}
// MustExecContext execs a NamedStmt, panicing on error
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) MustExecContext(ctx context.Context, arg interface{}) sql.Result {
res, err := n.ExecContext(ctx, arg)
if err != nil {
panic(err)
}
return res
}
// QueryxContext using this NamedStmt
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) QueryxContext(ctx context.Context, arg interface{}) (*Rows, error) {
r, err := n.QueryContext(ctx, arg)
if err != nil {
return nil, err
}
return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err
}
// QueryRowxContext this NamedStmt. Because of limitations with QueryRow, this is
// an alias for QueryRow.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) QueryRowxContext(ctx context.Context, arg interface{}) *Row {
return n.QueryRowContext(ctx, arg)
}
// SelectContext using this NamedStmt
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) SelectContext(ctx context.Context, dest interface{}, arg interface{}) error {
rows, err := n.QueryxContext(ctx, arg)
if err != nil {
return err
}
// if something happens here, we want to make sure the rows are Closed
defer rows.Close()
return scanAll(rows, dest, false)
}
// GetContext using this NamedStmt
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) GetContext(ctx context.Context, dest interface{}, arg interface{}) error {
r := n.QueryRowxContext(ctx, arg)
return r.scanAny(dest, false)
}
// NamedQueryContext binds a named query and then runs Query on the result using the
// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with
// map[string]interface{} types.
func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) {
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
return nil, err
}
return e.QueryxContext(ctx, q, args...)
}
// NamedExecContext uses BindStruct to get a query executable by the driver and
// then runs Exec on the result. Returns an error from the binding
// or the query excution itself.
func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) {
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
return nil, err
}
return e.ExecContext(ctx, q, args...)
}

136
vendor/github.com/jmoiron/sqlx/named_context_test.go generated vendored Normal file
View file

@ -0,0 +1,136 @@
// +build go1.8
package sqlx
import (
"context"
"database/sql"
"testing"
)
func TestNamedContextQueries(t *testing.T) {
RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
loadDefaultFixture(db, t)
test := Test{t}
var ns *NamedStmt
var err error
ctx := context.Background()
// Check that invalid preparations fail
ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first:name")
if err == nil {
t.Error("Expected an error with invalid prepared statement.")
}
ns, err = db.PrepareNamedContext(ctx, "invalid sql")
if err == nil {
t.Error("Expected an error with invalid prepared statement.")
}
// Check closing works as anticipated
ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first_name")
test.Error(err)
err = ns.Close()
test.Error(err)
ns, err = db.PrepareNamedContext(ctx, `
SELECT first_name, last_name, email
FROM person WHERE first_name=:first_name AND email=:email`)
test.Error(err)
// test Queryx w/ uses Query
p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"}
rows, err := ns.QueryxContext(ctx, p)
test.Error(err)
for rows.Next() {
var p2 Person
rows.StructScan(&p2)
if p.FirstName != p2.FirstName {
t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName)
}
if p.LastName != p2.LastName {
t.Errorf("got %s, expected %s", p.LastName, p2.LastName)
}
if p.Email != p2.Email {
t.Errorf("got %s, expected %s", p.Email, p2.Email)
}
}
// test Select
people := make([]Person, 0, 5)
err = ns.SelectContext(ctx, &people, p)
test.Error(err)
if len(people) != 1 {
t.Errorf("got %d results, expected %d", len(people), 1)
}
if p.FirstName != people[0].FirstName {
t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName)
}
if p.LastName != people[0].LastName {
t.Errorf("got %s, expected %s", p.LastName, people[0].LastName)
}
if p.Email != people[0].Email {
t.Errorf("got %s, expected %s", p.Email, people[0].Email)
}
// test Exec
ns, err = db.PrepareNamedContext(ctx, `
INSERT INTO person (first_name, last_name, email)
VALUES (:first_name, :last_name, :email)`)
test.Error(err)
js := Person{
FirstName: "Julien",
LastName: "Savea",
Email: "jsavea@ab.co.nz",
}
_, err = ns.ExecContext(ctx, js)
test.Error(err)
// Make sure we can pull him out again
p2 := Person{}
db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email)
if p2.Email != js.Email {
t.Errorf("expected %s, got %s", js.Email, p2.Email)
}
// test Txn NamedStmts
tx := db.MustBeginTx(ctx, nil)
txns := tx.NamedStmtContext(ctx, ns)
// We're going to add Steven in this txn
sl := Person{
FirstName: "Steven",
LastName: "Luatua",
Email: "sluatua@ab.co.nz",
}
_, err = txns.ExecContext(ctx, sl)
test.Error(err)
// then rollback...
tx.Rollback()
// looking for Steven after a rollback should fail
err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email)
if err != sql.ErrNoRows {
t.Errorf("expected no rows error, got %v", err)
}
// now do the same, but commit
tx = db.MustBeginTx(ctx, nil)
txns = tx.NamedStmtContext(ctx, ns)
_, err = txns.ExecContext(ctx, sl)
test.Error(err)
tx.Commit()
// looking for Steven after a Commit should succeed
err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email)
test.Error(err)
if p2.Email != sl.Email {
t.Errorf("expected %s, got %s", sl.Email, p2.Email)
}
})
}

View file

@ -12,6 +12,6 @@ behavior of standard Go accessors.
The first two are amply taken care of by `Reflect.Value.FieldByName`, and the third is The first two are amply taken care of by `Reflect.Value.FieldByName`, and the third is
addressed by `Reflect.Value.FieldByNameFunc`, but these don't quite understand struct addressed by `Reflect.Value.FieldByNameFunc`, but these don't quite understand struct
tags in the ways that are vital to most marshalers, and they are slow. tags in the ways that are vital to most marshallers, and they are slow.
This reflectx package extends reflect to achieve these goals. This reflectx package extends reflect to achieve these goals.

View file

@ -1,5 +1,5 @@
// Package reflectx implements extensions to the standard reflect lib suitable // Package reflectx implements extensions to the standard reflect lib suitable
// for implementing marshaling and unmarshaling packages. The main Mapper type // for implementing marshalling and unmarshalling packages. The main Mapper type
// allows for Go-compatible named attribute access, including accessing embedded // allows for Go-compatible named attribute access, including accessing embedded
// struct attributes and the ability to use functions and struct tags to // struct attributes and the ability to use functions and struct tags to
// customize field names. // customize field names.
@ -7,14 +7,13 @@
package reflectx package reflectx
import ( import (
"fmt"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
) )
// A FieldInfo is a collection of metadata about a struct field. // A FieldInfo is metadata for a struct field.
type FieldInfo struct { type FieldInfo struct {
Index []int Index []int
Path string Path string
@ -41,7 +40,8 @@ func (f StructMap) GetByPath(path string) *FieldInfo {
} }
// GetByTraversal returns a *FieldInfo for a given integer path. It is // GetByTraversal returns a *FieldInfo for a given integer path. It is
// analogous to reflect.FieldByIndex. // analogous to reflect.FieldByIndex, but using the cached traversal
// rather than re-executing the reflect machinery each time.
func (f StructMap) GetByTraversal(index []int) *FieldInfo { func (f StructMap) GetByTraversal(index []int) *FieldInfo {
if len(index) == 0 { if len(index) == 0 {
return nil return nil
@ -58,8 +58,8 @@ func (f StructMap) GetByTraversal(index []int) *FieldInfo {
} }
// Mapper is a general purpose mapper of names to struct fields. A Mapper // Mapper is a general purpose mapper of names to struct fields. A Mapper
// behaves like most marshallers, optionally obeying a field tag for name // behaves like most marshallers in the standard library, obeying a field tag
// mapping and a function to provide a basic mapping of fields to names. // for name mapping but also providing a basic transform function.
type Mapper struct { type Mapper struct {
cache map[reflect.Type]*StructMap cache map[reflect.Type]*StructMap
tagName string tagName string
@ -68,8 +68,8 @@ type Mapper struct {
mutex sync.Mutex mutex sync.Mutex
} }
// NewMapper returns a new mapper which optionally obeys the field tag given // NewMapper returns a new mapper using the tagName as its struct field tag.
// by tagName. If tagName is the empty string, it is ignored. // If tagName is the empty string, it is ignored.
func NewMapper(tagName string) *Mapper { func NewMapper(tagName string) *Mapper {
return &Mapper{ return &Mapper{
cache: make(map[reflect.Type]*StructMap), cache: make(map[reflect.Type]*StructMap),
@ -127,7 +127,7 @@ func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value {
return r return r
} }
// FieldByName returns a field by the its mapped name as a reflect.Value. // FieldByName returns a field by its mapped name as a reflect.Value.
// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind. // Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind.
// Returns zero Value if the name is not found. // Returns zero Value if the name is not found.
func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value { func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value {
@ -182,11 +182,12 @@ func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int {
return r return r
} }
// FieldByIndexes returns a value for a particular struct traversal. // FieldByIndexes returns a value for the field given by the struct traversal
// for the given value.
func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
for _, i := range indexes { for _, i := range indexes {
v = reflect.Indirect(v).Field(i) v = reflect.Indirect(v).Field(i)
// if this is a pointer, it's possible it is nil // if this is a pointer and it's nil, allocate a new value and set it
if v.Kind() == reflect.Ptr && v.IsNil() { if v.Kind() == reflect.Ptr && v.IsNil() {
alloc := reflect.New(Deref(v.Type())) alloc := reflect.New(Deref(v.Type()))
v.Set(alloc) v.Set(alloc)
@ -225,13 +226,12 @@ type kinder interface {
// mustBe checks a value against a kind, panicing with a reflect.ValueError // mustBe checks a value against a kind, panicing with a reflect.ValueError
// if the kind isn't that which is required. // if the kind isn't that which is required.
func mustBe(v kinder, expected reflect.Kind) { func mustBe(v kinder, expected reflect.Kind) {
k := v.Kind() if k := v.Kind(); k != expected {
if k != expected {
panic(&reflect.ValueError{Method: methodName(), Kind: k}) panic(&reflect.ValueError{Method: methodName(), Kind: k})
} }
} }
// methodName is returns the caller of the function calling methodName // methodName returns the caller of the function calling methodName
func methodName() string { func methodName() string {
pc, _, _, _ := runtime.Caller(2) pc, _, _, _ := runtime.Caller(2)
f := runtime.FuncForPC(pc) f := runtime.FuncForPC(pc)
@ -257,19 +257,92 @@ func apnd(is []int, i int) []int {
return x return x
} }
type mapf func(string) string
// parseName parses the tag and the target name for the given field using
// the tagName (eg 'json' for `json:"foo"` tags), mapFunc for mapping the
// field's name to a target name, and tagMapFunc for mapping the tag to
// a target name.
func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) {
// first, set the fieldName to the field's name
fieldName = field.Name
// if a mapFunc is set, use that to override the fieldName
if mapFunc != nil {
fieldName = mapFunc(fieldName)
}
// if there's no tag to look for, return the field name
if tagName == "" {
return "", fieldName
}
// if this tag is not set using the normal convention in the tag,
// then return the fieldname.. this check is done because according
// to the reflect documentation:
// If the tag does not have the conventional format,
// the value returned by Get is unspecified.
// which doesn't sound great.
if !strings.Contains(string(field.Tag), tagName+":") {
return "", fieldName
}
// at this point we're fairly sure that we have a tag, so lets pull it out
tag = field.Tag.Get(tagName)
// if we have a mapper function, call it on the whole tag
// XXX: this is a change from the old version, which pulled out the name
// before the tagMapFunc could be run, but I think this is the right way
if tagMapFunc != nil {
tag = tagMapFunc(tag)
}
// finally, split the options from the name
parts := strings.Split(tag, ",")
fieldName = parts[0]
return tag, fieldName
}
// parseOptions parses options out of a tag string, skipping the name
func parseOptions(tag string) map[string]string {
parts := strings.Split(tag, ",")
options := make(map[string]string, len(parts))
if len(parts) > 1 {
for _, opt := range parts[1:] {
// short circuit potentially expensive split op
if strings.Contains(opt, "=") {
kv := strings.Split(opt, "=")
options[kv[0]] = kv[1]
continue
}
options[opt] = ""
}
}
return options
}
// getMapping returns a mapping for the t type, using the tagName, mapFunc and // getMapping returns a mapping for the t type, using the tagName, mapFunc and
// tagMapFunc to determine the canonical names of fields. // tagMapFunc to determine the canonical names of fields.
func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string) string) *StructMap { func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap {
m := []*FieldInfo{} m := []*FieldInfo{}
root := &FieldInfo{} root := &FieldInfo{}
queue := []typeQueue{} queue := []typeQueue{}
queue = append(queue, typeQueue{Deref(t), root, ""}) queue = append(queue, typeQueue{Deref(t), root, ""})
QueueLoop:
for len(queue) != 0 { for len(queue) != 0 {
// pop the first item off of the queue // pop the first item off of the queue
tq := queue[0] tq := queue[0]
queue = queue[1:] queue = queue[1:]
// ignore recursive field
for p := tq.fi.Parent; p != nil; p = p.Parent {
if tq.fi.Field.Type == p.Field.Type {
continue QueueLoop
}
}
nChildren := 0 nChildren := 0
if tq.t.Kind() == reflect.Struct { if tq.t.Kind() == reflect.Struct {
nChildren = tq.t.NumField() nChildren = tq.t.NumField()
@ -278,53 +351,31 @@ func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string)
// iterate through all of its fields // iterate through all of its fields
for fieldPos := 0; fieldPos < nChildren; fieldPos++ { for fieldPos := 0; fieldPos < nChildren; fieldPos++ {
f := tq.t.Field(fieldPos) f := tq.t.Field(fieldPos)
fi := FieldInfo{} // parse the tag and the target name using the mapping options for this field
fi.Field = f tag, name := parseName(f, tagName, mapFunc, tagMapFunc)
fi.Zero = reflect.New(f.Type).Elem()
fi.Options = map[string]string{}
var tag, name string
if tagName != "" && strings.Contains(string(f.Tag), tagName+":") {
tag = f.Tag.Get(tagName)
name = tag
} else {
if mapFunc != nil {
name = mapFunc(f.Name)
}
}
parts := strings.Split(name, ",")
if len(parts) > 1 {
name = parts[0]
for _, opt := range parts[1:] {
kv := strings.Split(opt, "=")
if len(kv) > 1 {
fi.Options[kv[0]] = kv[1]
} else {
fi.Options[kv[0]] = ""
}
}
}
if tagMapFunc != nil {
tag = tagMapFunc(tag)
}
fi.Name = name
if tq.pp == "" || (tq.pp == "" && tag == "") {
fi.Path = fi.Name
} else {
fi.Path = fmt.Sprintf("%s.%s", tq.pp, fi.Name)
}
// if the name is "-", disabled via a tag, skip it // if the name is "-", disabled via a tag, skip it
if name == "-" { if name == "-" {
continue continue
} }
fi := FieldInfo{
Field: f,
Name: name,
Zero: reflect.New(f.Type).Elem(),
Options: parseOptions(tag),
}
// if the path is empty this path is just the name
if tq.pp == "" {
fi.Path = fi.Name
} else {
fi.Path = tq.pp + "." + fi.Name
}
// skip unexported fields // skip unexported fields
if len(f.PkgPath) != 0 && !f.Anonymous { if len(f.PkgPath) != 0 && !f.Anonymous {
continue continue

View file

@ -247,11 +247,20 @@ func TestInlineStruct(t *testing.T) {
} }
} }
func TestRecursiveStruct(t *testing.T) {
type Person struct {
Parent *Person
}
m := NewMapperFunc("db", strings.ToLower)
var p *Person
m.TypeMap(reflect.TypeOf(p))
}
func TestFieldsEmbedded(t *testing.T) { func TestFieldsEmbedded(t *testing.T) {
m := NewMapper("db") m := NewMapper("db")
type Person struct { type Person struct {
Name string `db:"name"` Name string `db:"name,size=64"`
} }
type Place struct { type Place struct {
Name string `db:"name"` Name string `db:"name"`
@ -311,6 +320,9 @@ func TestFieldsEmbedded(t *testing.T) {
if fi.Path != "person.name" { if fi.Path != "person.name" {
t.Errorf("Expecting %s, got %s", "person.name", fi.Path) t.Errorf("Expecting %s, got %s", "person.name", fi.Path)
} }
if fi.Options["size"] != "64" {
t.Errorf("Expecting %s, got %s", "64", fi.Options["size"])
}
fi = fields.GetByTraversal([]int{1, 0}) fi = fields.GetByTraversal([]int{1, 0})
if fi == nil { if fi == nil {
@ -508,6 +520,312 @@ func TestMapping(t *testing.T) {
} }
} }
func TestGetByTraversal(t *testing.T) {
type C struct {
C0 int
C1 int
}
type B struct {
B0 string
B1 *C
}
type A struct {
A0 int
A1 B
}
testCases := []struct {
Index []int
ExpectedName string
ExpectNil bool
}{
{
Index: []int{0},
ExpectedName: "A0",
},
{
Index: []int{1, 0},
ExpectedName: "B0",
},
{
Index: []int{1, 1, 1},
ExpectedName: "C1",
},
{
Index: []int{3, 4, 5},
ExpectNil: true,
},
{
Index: []int{},
ExpectNil: true,
},
{
Index: nil,
ExpectNil: true,
},
}
m := NewMapperFunc("db", func(n string) string { return n })
tm := m.TypeMap(reflect.TypeOf(A{}))
for i, tc := range testCases {
fi := tm.GetByTraversal(tc.Index)
if tc.ExpectNil {
if fi != nil {
t.Errorf("%d: expected nil, got %v", i, fi)
}
continue
}
if fi == nil {
t.Errorf("%d: expected %s, got nil", i, tc.ExpectedName)
continue
}
if fi.Name != tc.ExpectedName {
t.Errorf("%d: expected %s, got %s", i, tc.ExpectedName, fi.Name)
}
}
}
// TestMapperMethodsByName tests Mapper methods FieldByName and TraversalsByName
func TestMapperMethodsByName(t *testing.T) {
type C struct {
C0 string
C1 int
}
type B struct {
B0 *C `db:"B0"`
B1 C `db:"B1"`
B2 string `db:"B2"`
}
type A struct {
A0 *B `db:"A0"`
B `db:"A1"`
A2 int
a3 int
}
val := &A{
A0: &B{
B0: &C{C0: "0", C1: 1},
B1: C{C0: "2", C1: 3},
B2: "4",
},
B: B{
B0: nil,
B1: C{C0: "5", C1: 6},
B2: "7",
},
A2: 8,
}
testCases := []struct {
Name string
ExpectInvalid bool
ExpectedValue interface{}
ExpectedIndexes []int
}{
{
Name: "A0.B0.C0",
ExpectedValue: "0",
ExpectedIndexes: []int{0, 0, 0},
},
{
Name: "A0.B0.C1",
ExpectedValue: 1,
ExpectedIndexes: []int{0, 0, 1},
},
{
Name: "A0.B1.C0",
ExpectedValue: "2",
ExpectedIndexes: []int{0, 1, 0},
},
{
Name: "A0.B1.C1",
ExpectedValue: 3,
ExpectedIndexes: []int{0, 1, 1},
},
{
Name: "A0.B2",
ExpectedValue: "4",
ExpectedIndexes: []int{0, 2},
},
{
Name: "A1.B0.C0",
ExpectedValue: "",
ExpectedIndexes: []int{1, 0, 0},
},
{
Name: "A1.B0.C1",
ExpectedValue: 0,
ExpectedIndexes: []int{1, 0, 1},
},
{
Name: "A1.B1.C0",
ExpectedValue: "5",
ExpectedIndexes: []int{1, 1, 0},
},
{
Name: "A1.B1.C1",
ExpectedValue: 6,
ExpectedIndexes: []int{1, 1, 1},
},
{
Name: "A1.B2",
ExpectedValue: "7",
ExpectedIndexes: []int{1, 2},
},
{
Name: "A2",
ExpectedValue: 8,
ExpectedIndexes: []int{2},
},
{
Name: "XYZ",
ExpectInvalid: true,
ExpectedIndexes: []int{},
},
{
Name: "a3",
ExpectInvalid: true,
ExpectedIndexes: []int{},
},
}
// build the names array from the test cases
names := make([]string, len(testCases))
for i, tc := range testCases {
names[i] = tc.Name
}
m := NewMapperFunc("db", func(n string) string { return n })
v := reflect.ValueOf(val)
values := m.FieldsByName(v, names)
if len(values) != len(testCases) {
t.Errorf("expected %d values, got %d", len(testCases), len(values))
t.FailNow()
}
indexes := m.TraversalsByName(v.Type(), names)
if len(indexes) != len(testCases) {
t.Errorf("expected %d traversals, got %d", len(testCases), len(indexes))
t.FailNow()
}
for i, val := range values {
tc := testCases[i]
traversal := indexes[i]
if !reflect.DeepEqual(tc.ExpectedIndexes, traversal) {
t.Errorf("expected %v, got %v", tc.ExpectedIndexes, traversal)
t.FailNow()
}
val = reflect.Indirect(val)
if tc.ExpectInvalid {
if val.IsValid() {
t.Errorf("%d: expected zero value, got %v", i, val)
}
continue
}
if !val.IsValid() {
t.Errorf("%d: expected valid value, got %v", i, val)
continue
}
actualValue := reflect.Indirect(val).Interface()
if !reflect.DeepEqual(tc.ExpectedValue, actualValue) {
t.Errorf("%d: expected %v, got %v", i, tc.ExpectedValue, actualValue)
}
}
}
func TestFieldByIndexes(t *testing.T) {
type C struct {
C0 bool
C1 string
C2 int
C3 map[string]int
}
type B struct {
B1 C
B2 *C
}
type A struct {
A1 B
A2 *B
}
testCases := []struct {
value interface{}
indexes []int
expectedValue interface{}
readOnly bool
}{
{
value: A{
A1: B{B1: C{C0: true}},
},
indexes: []int{0, 0, 0},
expectedValue: true,
readOnly: true,
},
{
value: A{
A2: &B{B2: &C{C1: "answer"}},
},
indexes: []int{1, 1, 1},
expectedValue: "answer",
readOnly: true,
},
{
value: &A{},
indexes: []int{1, 1, 3},
expectedValue: map[string]int{},
},
}
for i, tc := range testCases {
checkResults := func(v reflect.Value) {
if tc.expectedValue == nil {
if !v.IsNil() {
t.Errorf("%d: expected nil, actual %v", i, v.Interface())
}
} else {
if !reflect.DeepEqual(tc.expectedValue, v.Interface()) {
t.Errorf("%d: expected %v, actual %v", i, tc.expectedValue, v.Interface())
}
}
}
checkResults(FieldByIndexes(reflect.ValueOf(tc.value), tc.indexes))
if tc.readOnly {
checkResults(FieldByIndexesReadOnly(reflect.ValueOf(tc.value), tc.indexes))
}
}
}
func TestMustBe(t *testing.T) {
typ := reflect.TypeOf(E1{})
mustBe(typ, reflect.Struct)
defer func() {
if r := recover(); r != nil {
valueErr, ok := r.(*reflect.ValueError)
if !ok {
t.Errorf("unexpected Method: %s", valueErr.Method)
t.Error("expected panic with *reflect.ValueError")
return
}
if valueErr.Method != "github.com/jmoiron/sqlx/reflectx.TestMustBe" {
}
if valueErr.Kind != reflect.String {
t.Errorf("unexpected Kind: %s", valueErr.Kind)
}
} else {
t.Error("expected panic")
}
}()
typ = reflect.TypeOf("string")
mustBe(typ, reflect.Struct)
t.Error("got here, didn't expect to")
}
type E1 struct { type E1 struct {
A int A int
} }

View file

@ -10,6 +10,7 @@ import (
"path/filepath" "path/filepath"
"reflect" "reflect"
"strings" "strings"
"sync"
"github.com/jmoiron/sqlx/reflectx" "github.com/jmoiron/sqlx/reflectx"
) )
@ -17,7 +18,7 @@ import (
// Although the NameMapper is convenient, in practice it should not // Although the NameMapper is convenient, in practice it should not
// be relied on except for application code. If you are writing a library // be relied on except for application code. If you are writing a library
// that uses sqlx, you should be aware that the name mappings you expect // that uses sqlx, you should be aware that the name mappings you expect
// can be overridded by your user's application. // can be overridden by your user's application.
// NameMapper is used to map column names to struct field names. By default, // NameMapper is used to map column names to struct field names. By default,
// it uses strings.ToLower to lowercase struct field names. It can be set // it uses strings.ToLower to lowercase struct field names. It can be set
@ -30,8 +31,14 @@ var origMapper = reflect.ValueOf(NameMapper)
// importers have time to customize the NameMapper. // importers have time to customize the NameMapper.
var mpr *reflectx.Mapper var mpr *reflectx.Mapper
// mprMu protects mpr.
var mprMu sync.Mutex
// mapper returns a valid mapper using the configured NameMapper func. // mapper returns a valid mapper using the configured NameMapper func.
func mapper() *reflectx.Mapper { func mapper() *reflectx.Mapper {
mprMu.Lock()
defer mprMu.Unlock()
if mpr == nil { if mpr == nil {
mpr = reflectx.NewMapperFunc("db", NameMapper) mpr = reflectx.NewMapperFunc("db", NameMapper)
} else if origMapper != reflect.ValueOf(NameMapper) { } else if origMapper != reflect.ValueOf(NameMapper) {
@ -289,21 +296,26 @@ func (db *DB) BindNamed(query string, arg interface{}) (string, []interface{}, e
} }
// NamedQuery using this DB. // NamedQuery using this DB.
// Any named placeholder parameters are replaced with fields from arg.
func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) { func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) {
return NamedQuery(db, query, arg) return NamedQuery(db, query, arg)
} }
// NamedExec using this DB. // NamedExec using this DB.
// Any named placeholder parameters are replaced with fields from arg.
func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) { func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) {
return NamedExec(db, query, arg) return NamedExec(db, query, arg)
} }
// Select using this DB. // Select using this DB.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) Select(dest interface{}, query string, args ...interface{}) error { func (db *DB) Select(dest interface{}, query string, args ...interface{}) error {
return Select(db, dest, query, args...) return Select(db, dest, query, args...)
} }
// Get using this DB. // Get using this DB.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (db *DB) Get(dest interface{}, query string, args ...interface{}) error { func (db *DB) Get(dest interface{}, query string, args ...interface{}) error {
return Get(db, dest, query, args...) return Get(db, dest, query, args...)
} }
@ -328,6 +340,7 @@ func (db *DB) Beginx() (*Tx, error) {
} }
// Queryx queries the database and returns an *sqlx.Rows. // Queryx queries the database and returns an *sqlx.Rows.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) { func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) {
r, err := db.DB.Query(query, args...) r, err := db.DB.Query(query, args...)
if err != nil { if err != nil {
@ -337,12 +350,14 @@ func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) {
} }
// QueryRowx queries the database and returns an *sqlx.Row. // QueryRowx queries the database and returns an *sqlx.Row.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) QueryRowx(query string, args ...interface{}) *Row { func (db *DB) QueryRowx(query string, args ...interface{}) *Row {
rows, err := db.DB.Query(query, args...) rows, err := db.DB.Query(query, args...)
return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper}
} }
// MustExec (panic) runs MustExec using this database. // MustExec (panic) runs MustExec using this database.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) MustExec(query string, args ...interface{}) sql.Result { func (db *DB) MustExec(query string, args ...interface{}) sql.Result {
return MustExec(db, query, args...) return MustExec(db, query, args...)
} }
@ -387,21 +402,25 @@ func (tx *Tx) BindNamed(query string, arg interface{}) (string, []interface{}, e
} }
// NamedQuery within a transaction. // NamedQuery within a transaction.
// Any named placeholder parameters are replaced with fields from arg.
func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) { func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) {
return NamedQuery(tx, query, arg) return NamedQuery(tx, query, arg)
} }
// NamedExec a named query within a transaction. // NamedExec a named query within a transaction.
// Any named placeholder parameters are replaced with fields from arg.
func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) { func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) {
return NamedExec(tx, query, arg) return NamedExec(tx, query, arg)
} }
// Select within a transaction. // Select within a transaction.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error { func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error {
return Select(tx, dest, query, args...) return Select(tx, dest, query, args...)
} }
// Queryx within a transaction. // Queryx within a transaction.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) { func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) {
r, err := tx.Tx.Query(query, args...) r, err := tx.Tx.Query(query, args...)
if err != nil { if err != nil {
@ -411,17 +430,21 @@ func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) {
} }
// QueryRowx within a transaction. // QueryRowx within a transaction.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row { func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row {
rows, err := tx.Tx.Query(query, args...) rows, err := tx.Tx.Query(query, args...)
return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper}
} }
// Get within a transaction. // Get within a transaction.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error { func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error {
return Get(tx, dest, query, args...) return Get(tx, dest, query, args...)
} }
// MustExec runs MustExec within a transaction. // MustExec runs MustExec within a transaction.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result { func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result {
return MustExec(tx, query, args...) return MustExec(tx, query, args...)
} }
@ -478,28 +501,34 @@ func (s *Stmt) Unsafe() *Stmt {
} }
// Select using the prepared statement. // Select using the prepared statement.
// Any placeholder parameters are replaced with supplied args.
func (s *Stmt) Select(dest interface{}, args ...interface{}) error { func (s *Stmt) Select(dest interface{}, args ...interface{}) error {
return Select(&qStmt{s}, dest, "", args...) return Select(&qStmt{s}, dest, "", args...)
} }
// Get using the prepared statement. // Get using the prepared statement.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (s *Stmt) Get(dest interface{}, args ...interface{}) error { func (s *Stmt) Get(dest interface{}, args ...interface{}) error {
return Get(&qStmt{s}, dest, "", args...) return Get(&qStmt{s}, dest, "", args...)
} }
// MustExec (panic) using this statement. Note that the query portion of the error // MustExec (panic) using this statement. Note that the query portion of the error
// output will be blank, as Stmt does not expose its query. // output will be blank, as Stmt does not expose its query.
// Any placeholder parameters are replaced with supplied args.
func (s *Stmt) MustExec(args ...interface{}) sql.Result { func (s *Stmt) MustExec(args ...interface{}) sql.Result {
return MustExec(&qStmt{s}, "", args...) return MustExec(&qStmt{s}, "", args...)
} }
// QueryRowx using this statement. // QueryRowx using this statement.
// Any placeholder parameters are replaced with supplied args.
func (s *Stmt) QueryRowx(args ...interface{}) *Row { func (s *Stmt) QueryRowx(args ...interface{}) *Row {
qs := &qStmt{s} qs := &qStmt{s}
return qs.QueryRowx("", args...) return qs.QueryRowx("", args...)
} }
// Queryx using this statement. // Queryx using this statement.
// Any placeholder parameters are replaced with supplied args.
func (s *Stmt) Queryx(args ...interface{}) (*Rows, error) { func (s *Stmt) Queryx(args ...interface{}) (*Rows, error) {
qs := &qStmt{s} qs := &qStmt{s}
return qs.Queryx("", args...) return qs.Queryx("", args...)
@ -576,7 +605,7 @@ func (r *Rows) StructScan(dest interface{}) error {
r.fields = m.TraversalsByName(v.Type(), columns) r.fields = m.TraversalsByName(v.Type(), columns)
// if we are not unsafe and are missing fields, return an error // if we are not unsafe and are missing fields, return an error
if f, err := missingFields(r.fields); err != nil && !r.unsafe { if f, err := missingFields(r.fields); err != nil && !r.unsafe {
return fmt.Errorf("missing destination name %s", columns[f]) return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
} }
r.values = make([]interface{}, len(columns)) r.values = make([]interface{}, len(columns))
r.started = true r.started = true
@ -626,6 +655,7 @@ func Preparex(p Preparer, query string) (*Stmt, error) {
// into dest, which must be a slice. If the slice elements are scannable, then // into dest, which must be a slice. If the slice elements are scannable, then
// the result set must have only one column. Otherwise, StructScan is used. // the result set must have only one column. Otherwise, StructScan is used.
// The *sql.Rows are closed automatically. // The *sql.Rows are closed automatically.
// Any placeholder parameters are replaced with supplied args.
func Select(q Queryer, dest interface{}, query string, args ...interface{}) error { func Select(q Queryer, dest interface{}, query string, args ...interface{}) error {
rows, err := q.Queryx(query, args...) rows, err := q.Queryx(query, args...)
if err != nil { if err != nil {
@ -639,6 +669,8 @@ func Select(q Queryer, dest interface{}, query string, args ...interface{}) erro
// Get does a QueryRow using the provided Queryer, and scans the resulting row // Get does a QueryRow using the provided Queryer, and scans the resulting row
// to dest. If dest is scannable, the result must only have one column. Otherwise, // to dest. If dest is scannable, the result must only have one column. Otherwise,
// StructScan is used. Get will return sql.ErrNoRows like row.Scan would. // StructScan is used. Get will return sql.ErrNoRows like row.Scan would.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func Get(q Queryer, dest interface{}, query string, args ...interface{}) error { func Get(q Queryer, dest interface{}, query string, args ...interface{}) error {
r := q.QueryRowx(query, args...) r := q.QueryRowx(query, args...)
return r.scanAny(dest, false) return r.scanAny(dest, false)
@ -669,6 +701,7 @@ func LoadFile(e Execer, path string) (*sql.Result, error) {
} }
// MustExec execs the query using e and panics if there was an error. // MustExec execs the query using e and panics if there was an error.
// Any placeholder parameters are replaced with supplied args.
func MustExec(e Execer, query string, args ...interface{}) sql.Result { func MustExec(e Execer, query string, args ...interface{}) sql.Result {
res, err := e.Exec(query, args...) res, err := e.Exec(query, args...)
if err != nil { if err != nil {
@ -691,6 +724,10 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error {
if r.err != nil { if r.err != nil {
return r.err return r.err
} }
if r.rows == nil {
r.err = sql.ErrNoRows
return r.err
}
defer r.rows.Close() defer r.rows.Close()
v := reflect.ValueOf(dest) v := reflect.ValueOf(dest)
@ -726,7 +763,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error {
fields := m.TraversalsByName(v.Type(), columns) fields := m.TraversalsByName(v.Type(), columns)
// if we are not unsafe and are missing fields, return an error // if we are not unsafe and are missing fields, return an error
if f, err := missingFields(fields); err != nil && !r.unsafe { if f, err := missingFields(fields); err != nil && !r.unsafe {
return fmt.Errorf("missing destination name %s", columns[f]) return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
} }
values := make([]interface{}, len(columns)) values := make([]interface{}, len(columns))
@ -779,7 +816,7 @@ func SliceScan(r ColScanner) ([]interface{}, error) {
// executes SQL from input). Please do not use this as a primary interface! // executes SQL from input). Please do not use this as a primary interface!
// This will modify the map sent to it in place, so reuse the same map with // This will modify the map sent to it in place, so reuse the same map with
// care. Columns which occur more than once in the result will overwrite // care. Columns which occur more than once in the result will overwrite
// eachother! // each other!
func MapScan(r ColScanner, dest map[string]interface{}) error { func MapScan(r ColScanner, dest map[string]interface{}) error {
// ignore r.started, since we needn't use reflect for anything. // ignore r.started, since we needn't use reflect for anything.
columns, err := r.Columns() columns, err := r.Columns()
@ -892,7 +929,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
fields := m.TraversalsByName(base, columns) fields := m.TraversalsByName(base, columns)
// if we are not unsafe and are missing fields, return an error // if we are not unsafe and are missing fields, return an error
if f, err := missingFields(fields); err != nil && !isUnsafe(rows) { if f, err := missingFields(fields); err != nil && !isUnsafe(rows) {
return fmt.Errorf("missing destination name %s", columns[f]) return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
} }
values = make([]interface{}, len(columns)) values = make([]interface{}, len(columns))
@ -902,6 +939,9 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
v = reflect.Indirect(vp) v = reflect.Indirect(vp)
err = fieldsByTraversal(v, fields, values, true) err = fieldsByTraversal(v, fields, values, true)
if err != nil {
return err
}
// scan into the struct field pointers and append to our results // scan into the struct field pointers and append to our results
err = rows.Scan(values...) err = rows.Scan(values...)
@ -919,6 +959,9 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
for rows.Next() { for rows.Next() {
vp = reflect.New(base) vp = reflect.New(base)
err = rows.Scan(vp.Interface()) err = rows.Scan(vp.Interface())
if err != nil {
return err
}
// append // append
if isPtr { if isPtr {
direct.Set(reflect.Append(direct, vp)) direct.Set(reflect.Append(direct, vp))
@ -937,7 +980,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
// anyway) works on a rows object. // anyway) works on a rows object.
// StructScan all rows from an sql.Rows or an sqlx.Rows into the dest slice. // StructScan all rows from an sql.Rows or an sqlx.Rows into the dest slice.
// StructScan will scan in the entire rows result, so if you need do not want to // StructScan will scan in the entire rows result, so if you do not want to
// allocate structs for the entire result, use Queryx and see sqlx.Rows.StructScan. // allocate structs for the entire result, use Queryx and see sqlx.Rows.StructScan.
// If rows is sqlx.Rows, it will use its mapper, otherwise it will use the default. // If rows is sqlx.Rows, it will use its mapper, otherwise it will use the default.
func StructScan(rows rowsi, dest interface{}) error { func StructScan(rows rowsi, dest interface{}) error {

335
vendor/github.com/jmoiron/sqlx/sqlx_context.go generated vendored Normal file
View file

@ -0,0 +1,335 @@
// +build go1.8
package sqlx
import (
"context"
"database/sql"
"fmt"
"io/ioutil"
"path/filepath"
"reflect"
)
// ConnectContext to a database and verify with a ping.
func ConnectContext(ctx context.Context, driverName, dataSourceName string) (*DB, error) {
db, err := Open(driverName, dataSourceName)
if err != nil {
return db, err
}
err = db.PingContext(ctx)
return db, err
}
// QueryerContext is an interface used by GetContext and SelectContext
type QueryerContext interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error)
QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row
}
// PreparerContext is an interface used by PreparexContext.
type PreparerContext interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
// ExecerContext is an interface used by MustExecContext and LoadFileContext
type ExecerContext interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
// ExtContext is a union interface which can bind, query, and exec, with Context
// used by NamedQueryContext and NamedExecContext.
type ExtContext interface {
binder
QueryerContext
ExecerContext
}
// SelectContext executes a query using the provided Queryer, and StructScans
// each row into dest, which must be a slice. If the slice elements are
// scannable, then the result set must have only one column. Otherwise,
// StructScan is used. The *sql.Rows are closed automatically.
// Any placeholder parameters are replaced with supplied args.
func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error {
rows, err := q.QueryxContext(ctx, query, args...)
if err != nil {
return err
}
// if something happens here, we want to make sure the rows are Closed
defer rows.Close()
return scanAll(rows, dest, false)
}
// PreparexContext prepares a statement.
//
// The provided context is used for the preparation of the statement, not for
// the execution of the statement.
func PreparexContext(ctx context.Context, p PreparerContext, query string) (*Stmt, error) {
s, err := p.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err
}
// GetContext does a QueryRow using the provided Queryer, and scans the
// resulting row to dest. If dest is scannable, the result must only have one
// column. Otherwise, StructScan is used. Get will return sql.ErrNoRows like
// row.Scan would. Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error {
r := q.QueryRowxContext(ctx, query, args...)
return r.scanAny(dest, false)
}
// LoadFileContext exec's every statement in a file (as a single call to Exec).
// LoadFileContext may return a nil *sql.Result if errors are encountered
// locating or reading the file at path. LoadFile reads the entire file into
// memory, so it is not suitable for loading large data dumps, but can be useful
// for initializing schemas or loading indexes.
//
// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3
// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting
// this by requiring something with DriverName() and then attempting to split the
// queries will be difficult to get right, and its current driver-specific behavior
// is deemed at least not complex in its incorrectness.
func LoadFileContext(ctx context.Context, e ExecerContext, path string) (*sql.Result, error) {
realpath, err := filepath.Abs(path)
if err != nil {
return nil, err
}
contents, err := ioutil.ReadFile(realpath)
if err != nil {
return nil, err
}
res, err := e.ExecContext(ctx, string(contents))
return &res, err
}
// MustExecContext execs the query using e and panics if there was an error.
// Any placeholder parameters are replaced with supplied args.
func MustExecContext(ctx context.Context, e ExecerContext, query string, args ...interface{}) sql.Result {
res, err := e.ExecContext(ctx, query, args...)
if err != nil {
panic(err)
}
return res
}
// PrepareNamedContext returns an sqlx.NamedStmt
func (db *DB) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) {
return prepareNamedContext(ctx, db, query)
}
// NamedQueryContext using this DB.
// Any named placeholder parameters are replaced with fields from arg.
func (db *DB) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*Rows, error) {
return NamedQueryContext(ctx, db, query, arg)
}
// NamedExecContext using this DB.
// Any named placeholder parameters are replaced with fields from arg.
func (db *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
return NamedExecContext(ctx, db, query, arg)
}
// SelectContext using this DB.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
return SelectContext(ctx, db, dest, query, args...)
}
// GetContext using this DB.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
return GetContext(ctx, db, dest, query, args...)
}
// PreparexContext returns an sqlx.Stmt instead of a sql.Stmt.
//
// The provided context is used for the preparation of the statement, not for
// the execution of the statement.
func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) {
return PreparexContext(ctx, db, query)
}
// QueryxContext queries the database and returns an *sqlx.Rows.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
r, err := db.DB.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err
}
// QueryRowxContext queries the database and returns an *sqlx.Row.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := db.DB.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper}
}
// MustBeginTx starts a transaction, and panics on error. Returns an *sqlx.Tx instead
// of an *sql.Tx.
//
// The provided context is used until the transaction is committed or rolled
// back. If the context is canceled, the sql package will roll back the
// transaction. Tx.Commit will return an error if the context provided to
// MustBeginContext is canceled.
func (db *DB) MustBeginTx(ctx context.Context, opts *sql.TxOptions) *Tx {
tx, err := db.BeginTxx(ctx, opts)
if err != nil {
panic(err)
}
return tx
}
// MustExecContext (panic) runs MustExec using this database.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result {
return MustExecContext(ctx, db, query, args...)
}
// BeginTxx begins a transaction and returns an *sqlx.Tx instead of an
// *sql.Tx.
//
// The provided context is used until the transaction is committed or rolled
// back. If the context is canceled, the sql package will roll back the
// transaction. Tx.Commit will return an error if the context provided to
// BeginxContext is canceled.
func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
tx, err := db.DB.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err
}
// StmtxContext returns a version of the prepared statement which runs within a
// transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt.
func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *Stmt {
var s *sql.Stmt
switch v := stmt.(type) {
case Stmt:
s = v.Stmt
case *Stmt:
s = v.Stmt
case sql.Stmt:
s = &v
case *sql.Stmt:
s = v
default:
panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type()))
}
return &Stmt{Stmt: tx.StmtContext(ctx, s), Mapper: tx.Mapper}
}
// NamedStmtContext returns a version of the prepared statement which runs
// within a transaction.
func (tx *Tx) NamedStmtContext(ctx context.Context, stmt *NamedStmt) *NamedStmt {
return &NamedStmt{
QueryString: stmt.QueryString,
Params: stmt.Params,
Stmt: tx.StmtxContext(ctx, stmt.Stmt),
}
}
// MustExecContext runs MustExecContext within a transaction.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result {
return MustExecContext(ctx, tx, query, args...)
}
// QueryxContext within a transaction and context.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
r, err := tx.Tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err
}
// SelectContext within a transaction and context.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
return SelectContext(ctx, tx, dest, query, args...)
}
// GetContext within a transaction and context.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (tx *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
return GetContext(ctx, tx, dest, query, args...)
}
// QueryRowxContext within a transaction and context.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := tx.Tx.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper}
}
// NamedExecContext using this Tx.
// Any named placeholder parameters are replaced with fields from arg.
func (tx *Tx) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
return NamedExecContext(ctx, tx, query, arg)
}
// SelectContext using the prepared statement.
// Any placeholder parameters are replaced with supplied args.
func (s *Stmt) SelectContext(ctx context.Context, dest interface{}, args ...interface{}) error {
return SelectContext(ctx, &qStmt{s}, dest, "", args...)
}
// GetContext using the prepared statement.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (s *Stmt) GetContext(ctx context.Context, dest interface{}, args ...interface{}) error {
return GetContext(ctx, &qStmt{s}, dest, "", args...)
}
// MustExecContext (panic) using this statement. Note that the query portion of
// the error output will be blank, as Stmt does not expose its query.
// Any placeholder parameters are replaced with supplied args.
func (s *Stmt) MustExecContext(ctx context.Context, args ...interface{}) sql.Result {
return MustExecContext(ctx, &qStmt{s}, "", args...)
}
// QueryRowxContext using this statement.
// Any placeholder parameters are replaced with supplied args.
func (s *Stmt) QueryRowxContext(ctx context.Context, args ...interface{}) *Row {
qs := &qStmt{s}
return qs.QueryRowxContext(ctx, "", args...)
}
// QueryxContext using this statement.
// Any placeholder parameters are replaced with supplied args.
func (s *Stmt) QueryxContext(ctx context.Context, args ...interface{}) (*Rows, error) {
qs := &qStmt{s}
return qs.QueryxContext(ctx, "", args...)
}
func (q *qStmt) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
return q.Stmt.QueryContext(ctx, args...)
}
func (q *qStmt) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
r, err := q.Stmt.QueryContext(ctx, args...)
if err != nil {
return nil, err
}
return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err
}
func (q *qStmt) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := q.Stmt.QueryContext(ctx, args...)
return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}
}
func (q *qStmt) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return q.Stmt.ExecContext(ctx, args...)
}

1344
vendor/github.com/jmoiron/sqlx/sqlx_context_test.go generated vendored Normal file

File diff suppressed because it is too large Load diff

View file

@ -591,11 +591,21 @@ func TestNilReceiver(t *testing.T) {
func TestNamedQuery(t *testing.T) { func TestNamedQuery(t *testing.T) {
var schema = Schema{ var schema = Schema{
create: ` create: `
CREATE TABLE place (
id integer PRIMARY KEY,
name text NULL
);
CREATE TABLE person ( CREATE TABLE person (
first_name text NULL, first_name text NULL,
last_name text NULL, last_name text NULL,
email text NULL email text NULL
); );
CREATE TABLE placeperson (
first_name text NULL,
last_name text NULL,
email text NULL,
place_id integer NULL
);
CREATE TABLE jsperson ( CREATE TABLE jsperson (
"FIRST" text NULL, "FIRST" text NULL,
last_name text NULL, last_name text NULL,
@ -604,6 +614,8 @@ func TestNamedQuery(t *testing.T) {
drop: ` drop: `
drop table person; drop table person;
drop table jsperson; drop table jsperson;
drop table place;
drop table placeperson;
`, `,
} }
@ -734,6 +746,76 @@ func TestNamedQuery(t *testing.T) {
db.Mapper = &old db.Mapper = &old
// Test nested structs
type Place struct {
ID int `db:"id"`
Name sql.NullString `db:"name"`
}
type PlacePerson struct {
FirstName sql.NullString `db:"first_name"`
LastName sql.NullString `db:"last_name"`
Email sql.NullString
Place Place `db:"place"`
}
pl := Place{
Name: sql.NullString{String: "myplace", Valid: true},
}
pp := PlacePerson{
FirstName: sql.NullString{String: "ben", Valid: true},
LastName: sql.NullString{String: "doe", Valid: true},
Email: sql.NullString{String: "ben@doe.com", Valid: true},
}
q2 := `INSERT INTO place (id, name) VALUES (1, :name)`
_, err = db.NamedExec(q2, pl)
if err != nil {
log.Fatal(err)
}
id := 1
pp.Place.ID = id
q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)`
_, err = db.NamedExec(q3, pp)
if err != nil {
log.Fatal(err)
}
pp2 := &PlacePerson{}
rows, err = db.NamedQuery(`
SELECT
first_name,
last_name,
email,
place.id AS "place.id",
place.name AS "place.name"
FROM placeperson
INNER JOIN place ON place.id = placeperson.place_id
WHERE
place.id=:place.id`, pp)
if err != nil {
log.Fatal(err)
}
for rows.Next() {
err = rows.StructScan(pp2)
if err != nil {
t.Error(err)
}
if pp2.FirstName.String != "ben" {
t.Error("Expected first name of `ben`, got " + pp2.FirstName.String)
}
if pp2.LastName.String != "doe" {
t.Error("Expected first name of `doe`, got " + pp2.LastName.String)
}
if pp2.Place.Name.String != "myplace" {
t.Error("Expected place name of `myplace`, got " + pp2.Place.Name.String)
}
if pp2.Place.ID != pp.Place.ID {
t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID)
}
}
}) })
} }
@ -885,6 +967,9 @@ func TestUsage(t *testing.T) {
t.Error("Expected an error") t.Error("Expected an error")
} }
err = stmt1.Get(&jason, "DoesNotExist User 2") err = stmt1.Get(&jason, "DoesNotExist User 2")
if err == nil {
t.Fatal(err)
}
stmt2, err := db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) stmt2, err := db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?"))
if err != nil { if err != nil {
@ -905,6 +990,10 @@ func TestUsage(t *testing.T) {
places := []*Place{} places := []*Place{}
err = db.Select(&places, "SELECT telcode FROM place ORDER BY telcode ASC") err = db.Select(&places, "SELECT telcode FROM place ORDER BY telcode ASC")
if err != nil {
t.Fatal(err)
}
usa, singsing, honkers := places[0], places[1], places[2] usa, singsing, honkers := places[0], places[1], places[2]
if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 {
@ -922,6 +1011,10 @@ func TestUsage(t *testing.T) {
// this test also verifies that you can use either a []Struct{} or a []*Struct{} // this test also verifies that you can use either a []Struct{} or a []*Struct{}
places2 := []Place{} places2 := []Place{}
err = db.Select(&places2, "SELECT * FROM place ORDER BY telcode ASC") err = db.Select(&places2, "SELECT * FROM place ORDER BY telcode ASC")
if err != nil {
t.Fatal(err)
}
usa, singsing, honkers = &places2[0], &places2[1], &places2[2] usa, singsing, honkers = &places2[0], &places2[1], &places2[2]
// this should return a type error that &p is not a pointer to a struct slice // this should return a type error that &p is not a pointer to a struct slice
@ -1276,8 +1369,9 @@ func TestBindMap(t *testing.T) {
type Message struct { type Message struct {
Text string `db:"string"` Text string `db:"string"`
Properties PropertyMap // Stored as JSON in the database Properties PropertyMap `db:"properties"` // Stored as JSON in the database
} }
type PropertyMap map[string]string type PropertyMap map[string]string
// Implement driver.Valuer and sql.Scanner interfaces on PropertyMap // Implement driver.Valuer and sql.Scanner interfaces on PropertyMap
@ -1314,7 +1408,7 @@ func TestEmbeddedMaps(t *testing.T) {
{"Hello, World", PropertyMap{"one": "1", "two": "2"}}, {"Hello, World", PropertyMap{"one": "1", "two": "2"}},
{"Thanks, Joy", PropertyMap{"pull": "request"}}, {"Thanks, Joy", PropertyMap{"pull": "request"}},
} }
q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties)` q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties);`
for _, m := range messages { for _, m := range messages {
_, err := db.NamedExec(q1, m) _, err := db.NamedExec(q1, m)
if err != nil { if err != nil {
@ -1324,19 +1418,19 @@ func TestEmbeddedMaps(t *testing.T) {
var count int var count int
err := db.Get(&count, "SELECT count(*) FROM message") err := db.Get(&count, "SELECT count(*) FROM message")
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
} }
if count != len(messages) { if count != len(messages) {
t.Errorf("Expected %d messages in DB, found %d", len(messages), count) t.Fatalf("Expected %d messages in DB, found %d", len(messages), count)
} }
var m Message var m Message
err = db.Get(&m, "SELECT * FROM message LIMIT 1") err = db.Get(&m, "SELECT * FROM message LIMIT 1;")
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
} }
if m.Properties == nil { if m.Properties == nil {
t.Error("Expected m.Properties to not be nil, but it was.") t.Fatal("Expected m.Properties to not be nil, but it was.")
} }
}) })
} }
@ -1359,31 +1453,25 @@ func TestIssue197(t *testing.T) {
if err = db.Get(&v, `SELECT '{"a": "b"}' AS raw`); err != nil { if err = db.Get(&v, `SELECT '{"a": "b"}' AS raw`); err != nil {
t.Fatal(err) t.Fatal(err)
} }
fmt.Printf("%s: v %s\n", db.DriverName(), v.Raw)
if err = db.Get(&q, `SELECT 'null' AS raw`); err != nil { if err = db.Get(&q, `SELECT 'null' AS raw`); err != nil {
t.Fatal(err) t.Fatal(err)
} }
fmt.Printf("%s: v %s\n", db.DriverName(), v.Raw)
var v2, q2 Var2 var v2, q2 Var2
if err = db.Get(&v2, `SELECT '{"a": "b"}' AS raw`); err != nil { if err = db.Get(&v2, `SELECT '{"a": "b"}' AS raw`); err != nil {
t.Fatal(err) t.Fatal(err)
} }
fmt.Printf("%s: v2 %s\n", db.DriverName(), v2.Raw)
if err = db.Get(&q2, `SELECT 'null' AS raw`); err != nil { if err = db.Get(&q2, `SELECT 'null' AS raw`); err != nil {
t.Fatal(err) t.Fatal(err)
} }
fmt.Printf("%s: v2 %s\n", db.DriverName(), v2.Raw)
var v3, q3 Var3 var v3, q3 Var3
if err = db.QueryRow(`SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil { if err = db.QueryRow(`SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil {
t.Fatal(err) t.Fatal(err)
} }
fmt.Printf("v3 %s\n", v3.Raw)
if err = db.QueryRow(`SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil { if err = db.QueryRow(`SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil {
t.Fatal(err) t.Fatal(err)
} }
fmt.Printf("v3 %s\n", v3.Raw)
t.Fail() t.Fail()
}) })
} }
@ -1649,6 +1737,36 @@ func BenchmarkIn(b *testing.B) {
} }
} }
func BenchmarkIn1k(b *testing.B) {
q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?`
var vals [1000]interface{}
for i := 0; i < b.N; i++ {
_, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...)
}
}
func BenchmarkIn1kInt(b *testing.B) {
q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?`
var vals [1000]int
for i := 0; i < b.N; i++ {
_, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...)
}
}
func BenchmarkIn1kString(b *testing.B) {
q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?`
var vals [1000]string
for i := 0; i < b.N; i++ {
_, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...)
}
}
func BenchmarkRebind(b *testing.B) { func BenchmarkRebind(b *testing.B) {
b.StopTimer() b.StopTimer()
q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)` q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`

View file

@ -39,6 +39,9 @@ func (g *GzippedText) Scan(src interface{}) error {
return errors.New("Incompatible type for GzippedText") return errors.New("Incompatible type for GzippedText")
} }
reader, err := gzip.NewReader(bytes.NewReader(source)) reader, err := gzip.NewReader(bytes.NewReader(source))
if err != nil {
return err
}
defer reader.Close() defer reader.Close()
b, err := ioutil.ReadAll(reader) b, err := ioutil.ReadAll(reader)
if err != nil { if err != nil {
@ -54,9 +57,14 @@ func (g *GzippedText) Scan(src interface{}) error {
// implements `Unmarshal`, which unmarshals the json within to an interface{} // implements `Unmarshal`, which unmarshals the json within to an interface{}
type JSONText json.RawMessage type JSONText json.RawMessage
var emptyJSON = JSONText("{}")
// MarshalJSON returns the *j as the JSON encoding of j. // MarshalJSON returns the *j as the JSON encoding of j.
func (j *JSONText) MarshalJSON() ([]byte, error) { func (j JSONText) MarshalJSON() ([]byte, error) {
return *j, nil if len(j) == 0 {
return emptyJSON, nil
}
return j, nil
} }
// UnmarshalJSON sets *j to a copy of data // UnmarshalJSON sets *j to a copy of data
@ -66,7 +74,6 @@ func (j *JSONText) UnmarshalJSON(data []byte) error {
} }
*j = append((*j)[0:0], data...) *j = append((*j)[0:0], data...)
return nil return nil
} }
// Value returns j as a value. This does a validating unmarshal into another // Value returns j as a value. This does a validating unmarshal into another
@ -83,11 +90,17 @@ func (j JSONText) Value() (driver.Value, error) {
// Scan stores the src in *j. No validation is done. // Scan stores the src in *j. No validation is done.
func (j *JSONText) Scan(src interface{}) error { func (j *JSONText) Scan(src interface{}) error {
var source []byte var source []byte
switch src.(type) { switch t := src.(type) {
case string: case string:
source = []byte(src.(string)) source = []byte(t)
case []byte: case []byte:
source = src.([]byte) if len(t) == 0 {
source = emptyJSON
} else {
source = t
}
case nil:
*j = emptyJSON
default: default:
return errors.New("Incompatible type for JSONText") return errors.New("Incompatible type for JSONText")
} }
@ -97,10 +110,63 @@ func (j *JSONText) Scan(src interface{}) error {
// Unmarshal unmarshal's the json in j to v, as in json.Unmarshal. // Unmarshal unmarshal's the json in j to v, as in json.Unmarshal.
func (j *JSONText) Unmarshal(v interface{}) error { func (j *JSONText) Unmarshal(v interface{}) error {
if len(*j) == 0 {
*j = emptyJSON
}
return json.Unmarshal([]byte(*j), v) return json.Unmarshal([]byte(*j), v)
} }
// Pretty printing for JSONText types // String supports pretty printing for JSONText types.
func (j JSONText) String() string { func (j JSONText) String() string {
return string(j) return string(j)
} }
// NullJSONText represents a JSONText that may be null.
// NullJSONText implements the scanner interface so
// it can be used as a scan destination, similar to NullString.
type NullJSONText struct {
JSONText
Valid bool // Valid is true if JSONText is not NULL
}
// Scan implements the Scanner interface.
func (n *NullJSONText) Scan(value interface{}) error {
if value == nil {
n.JSONText, n.Valid = emptyJSON, false
return nil
}
n.Valid = true
return n.JSONText.Scan(value)
}
// Value implements the driver Valuer interface.
func (n NullJSONText) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.JSONText.Value()
}
// BitBool is an implementation of a bool for the MySQL type BIT(1).
// This type allows you to avoid wasting an entire byte for MySQL's boolean type TINYINT.
type BitBool bool
// Value implements the driver.Valuer interface,
// and turns the BitBool into a bitfield (BIT(1)) for MySQL storage.
func (b BitBool) Value() (driver.Value, error) {
if b {
return []byte{1}, nil
}
return []byte{0}, nil
}
// Scan implements the sql.Scanner interface,
// and turns the bitfield incoming from MySQL into a BitBool
func (b *BitBool) Scan(src interface{}) error {
v, ok := src.([]byte)
if !ok {
return errors.New("bad []byte type assertion")
}
*b = v[0] == 1
return nil
}

View file

@ -39,4 +39,89 @@ func TestJSONText(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("Was expecting invalid json to fail!") t.Errorf("Was expecting invalid json to fail!")
} }
j = JSONText("")
v, err = j.Value()
if err != nil {
t.Errorf("Was not expecting an error")
}
err = (&j).Scan(v)
if err != nil {
t.Errorf("Was not expecting an error")
}
j = JSONText(nil)
v, err = j.Value()
if err != nil {
t.Errorf("Was not expecting an error")
}
err = (&j).Scan(v)
if err != nil {
t.Errorf("Was not expecting an error")
}
}
func TestNullJSONText(t *testing.T) {
j := NullJSONText{}
err := j.Scan(`{"foo": 1, "bar": 2}`)
if err != nil {
t.Errorf("Was not expecting an error")
}
v, err := j.Value()
if err != nil {
t.Errorf("Was not expecting an error")
}
err = (&j).Scan(v)
if err != nil {
t.Errorf("Was not expecting an error")
}
m := map[string]interface{}{}
j.Unmarshal(&m)
if m["foo"].(float64) != 1 || m["bar"].(float64) != 2 {
t.Errorf("Expected valid json but got some garbage instead? %#v", m)
}
j = NullJSONText{}
err = j.Scan(nil)
if err != nil {
t.Errorf("Was not expecting an error")
}
if j.Valid != false {
t.Errorf("Expected valid to be false, but got true")
}
}
func TestBitBool(t *testing.T) {
// Test true value
var b BitBool = true
v, err := b.Value()
if err != nil {
t.Errorf("Cannot return error")
}
err = (&b).Scan(v)
if err != nil {
t.Errorf("Was not expecting an error")
}
if !b {
t.Errorf("Was expecting the bool we sent in (true), got %v", b)
}
// Test false value
b = false
v, err = b.Value()
if err != nil {
t.Errorf("Cannot return error")
}
err = (&b).Scan(v)
if err != nil {
t.Errorf("Was not expecting an error")
}
if b {
t.Errorf("Was expecting the bool we sent in (false), got %v", b)
}
} }