diff --git a/core/api/request/page.go b/core/api/request/page.go index f70c20a7..7d91d489 100644 --- a/core/api/request/page.go +++ b/core/api/request/page.go @@ -16,11 +16,12 @@ import ( "strings" "time" + "golang.org/x/net/html" + "github.com/documize/community/core/api/endpoint/models" "github.com/documize/community/core/api/entity" "github.com/documize/community/core/log" "github.com/documize/community/core/streamutil" - "github.com/documize/community/domain/link" "github.com/jmoiron/sqlx" ) @@ -249,7 +250,7 @@ func (p *Persister) UpdatePage(page entity.Page, refID, userID string, skipRevis } // find any content links in the HTML - links := link.GetContentLinks(page.Body) + links := GetContentLinks(page.Body) // get a copy of previously saved links previousLinks, _ := p.GetPageLinks(page.DocumentID, page.RefID) @@ -497,3 +498,58 @@ func (p *Persister) GetNextPageSequence(documentID string) (maxSeq float64, err return } + +// GetContentLinks returns Documize generated links. +// such links have an identifying attribute e.g. tag + isAnchor := t.Data == "a" + if !isAnchor { + continue + } + + // Extract the content link + ok, link := getLink(t) + if ok { + links = append(links, link) + } + } + } +} + +// Helper function to pull the href attribute from a Token +func getLink(t html.Token) (ok bool, link entity.Link) { + ok = false + + // Iterate over all of the Token's attributes until we find an "href" + for _, a := range t.Attr { + switch a.Key { + case "data-documize": + ok = true + case "data-link-id": + link.RefID = strings.TrimSpace(a.Val) + case "data-link-space-id": + link.FolderID = strings.TrimSpace(a.Val) + case "data-link-target-document-id": + link.TargetDocumentID = strings.TrimSpace(a.Val) + case "data-link-target-id": + link.TargetID = strings.TrimSpace(a.Val) + case "data-link-type": + link.LinkType = strings.TrimSpace(a.Val) + } + } + + return +} diff --git a/domain/attachment/endpoint.go b/domain/attachment/endpoint.go new file mode 100644 index 00000000..99b792df --- /dev/null +++ b/domain/attachment/endpoint.go @@ -0,0 +1,224 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package attachment + +import ( + "bytes" + "database/sql" + "fmt" + "io" + "mime" + "net/http" + + "github.com/documize/community/core/env" + "github.com/documize/community/core/request" + "github.com/documize/community/core/response" + "github.com/documize/community/core/secrets" + "github.com/documize/community/core/uniqueid" + "github.com/documize/community/domain" + "github.com/documize/community/domain/document" + "github.com/documize/community/model/attachment" + "github.com/documize/community/model/audit" + uuid "github.com/nu7hatch/gouuid" +) + +// Handler contains the runtime information such as logging and database. +type Handler struct { + Runtime *env.Runtime + Store *domain.Store +} + +// Download is the end-point that responds to a request for a particular attachment +// by sending the requested file to the client. +func (h *Handler) Download(w http.ResponseWriter, r *http.Request) { + method := "attachment.Download" + ctx := domain.GetRequestContext(r) + + a, err := h.Store.Attachment.GetAttachment(ctx, request.Param(r, "orgID"), request.Param(r, "attachmentID")) + + if err == sql.ErrNoRows { + response.WriteNotFoundError(w, method, request.Param(r, "fileID")) + return + } + if err != nil { + response.WriteServerError(w, method, err) + return + } + + typ := mime.TypeByExtension("." + a.Extension) + if typ == "" { + typ = "application/octet-stream" + } + + w.Header().Set("Content-Type", typ) + w.Header().Set("Content-Disposition", `Attachment; filename="`+a.Filename+`" ; `+`filename*="`+a.Filename+`"`) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(a.Data))) + w.WriteHeader(http.StatusOK) + + _, err = w.Write(a.Data) + if err != nil { + h.Runtime.Log.Error("writing attachment", err) + return + } + + h.Store.Audit.Record(ctx, audit.EventTypeAttachmentDownload) +} + +// Get is an end-point that returns all of the attachments of a particular documentID. +func (h *Handler) Get(w http.ResponseWriter, r *http.Request) { + method := "attachment.GetAttachments" + ctx := domain.GetRequestContext(r) + + documentID := request.Param(r, "documentID") + if len(documentID) == 0 { + response.WriteMissingDataError(w, method, "documentID") + return + } + + if !document.CanViewDocument(ctx, *h.Store, documentID) { + response.WriteForbiddenError(w) + return + } + + a, err := h.Store.Attachment.GetAttachments(ctx, documentID) + if err != nil && err != sql.ErrNoRows { + response.WriteServerError(w, method, err) + return + } + + if len(a) == 0 { + a = []attachment.Attachment{} + } + + response.WriteJSON(w, a) +} + +// Delete is an endpoint that deletes a particular document attachment. +func (h *Handler) Delete(w http.ResponseWriter, r *http.Request) { + method := "attachment.DeleteAttachment" + ctx := domain.GetRequestContext(r) + + documentID := request.Param(r, "documentID") + if len(documentID) == 0 { + response.WriteMissingDataError(w, method, "documentID") + return + } + + attachmentID := request.Param(r, "attachmentID") + if len(attachmentID) == 0 { + response.WriteMissingDataError(w, method, "attachmentID") + return + } + + if !document.CanChangeDocument(ctx, *h.Store, documentID) { + response.WriteForbiddenError(w) + return + } + + var err error + ctx.Transaction, err = h.Runtime.Db.Beginx() + if err != nil { + response.WriteServerError(w, method, err) + return + } + + _, err = h.Store.Attachment.Delete(ctx, attachmentID) + if err != nil { + ctx.Transaction.Rollback() + response.WriteServerError(w, method, err) + return + } + + // Mark references to this document as orphaned + err = h.Store.Link.MarkOrphanAttachmentLink(ctx, attachmentID) + if err != nil { + ctx.Transaction.Rollback() + response.WriteServerError(w, method, err) + return + } + + h.Store.Audit.Record(ctx, audit.EventTypeAttachmentDelete) + + ctx.Transaction.Commit() + + response.WriteEmpty(w) +} + +// Add stores files against a document. +func (h *Handler) Add(w http.ResponseWriter, r *http.Request) { + method := "attachment.Add" + ctx := domain.GetRequestContext(r) + + documentID := request.Param(r, "documentID") + if len(documentID) == 0 { + response.WriteMissingDataError(w, method, "documentID") + return + } + + if !document.CanChangeDocument(ctx, *h.Store, documentID) { + response.WriteForbiddenError(w) + return + } + + filedata, filename, err := r.FormFile("attachment") + + if err != nil { + response.WriteMissingDataError(w, method, "attachment") + return + } + + b := new(bytes.Buffer) + _, err = io.Copy(b, filedata) + if err != nil { + response.WriteServerError(w, method, err) + return + } + + var job = "some-uuid" + + newUUID, err := uuid.NewV4() + if err != nil { + response.WriteServerError(w, method, err) + return + } + + job = newUUID.String() + + var a attachment.Attachment + refID := uniqueid.Generate() + a.RefID = refID + a.DocumentID = documentID + a.Job = job + random := secrets.GenerateSalt() + a.FileID = random[0:9] + a.Filename = filename.Filename + a.Data = b.Bytes() + + ctx.Transaction, err = h.Runtime.Db.Beginx() + if err != nil { + response.WriteServerError(w, method, err) + return + } + + err = h.Store.Attachment.Add(ctx, a) + if err != nil { + ctx.Transaction.Rollback() + response.WriteServerError(w, method, err) + return + } + + h.Store.Audit.Record(ctx, audit.EventTypeAttachmentAdd) + + ctx.Transaction.Commit() + + response.WriteEmpty(w) +} diff --git a/domain/attachment/mysql/store.go b/domain/attachment/mysql/store.go new file mode 100644 index 00000000..42273cd9 --- /dev/null +++ b/domain/attachment/mysql/store.go @@ -0,0 +1,104 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package attachment + +import ( + "strings" + "time" + + "github.com/documize/community/domain" + "github.com/documize/community/domain/store/mysql" + "github.com/pkg/errors" + + "github.com/documize/community/core/env" + "github.com/documize/community/core/streamutil" + "github.com/documize/community/model/attachment" +) + +// Scope provides data access to MySQL. +type Scope struct { + Runtime *env.Runtime +} + +// Add inserts the given record into the database attachement table. +func (s Scope) Add(ctx domain.RequestContext, a attachment.Attachment) (err error) { + a.OrgID = ctx.OrgID + a.Created = time.Now().UTC() + a.Revised = time.Now().UTC() + bits := strings.Split(a.Filename, ".") + a.Extension = bits[len(bits)-1] + + stmt, err := ctx.Transaction.Preparex("INSERT INTO attachment (refid, orgid, documentid, job, fileid, filename, data, extension, created, revised) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)") + defer streamutil.Close(stmt) + + if err != nil { + err = errors.Wrap(err, "prepare insert attachment") + return + } + + _, err = stmt.Exec(a.RefID, a.OrgID, a.DocumentID, a.Job, a.FileID, a.Filename, a.Data, a.Extension, a.Created, a.Revised) + if err != nil { + err = errors.Wrap(err, "execute insert attachment") + return + } + + return +} + +// GetAttachment returns the database attachment record specified by the parameters. +func (s Scope) GetAttachment(ctx domain.RequestContext, orgID, attachmentID string) (a attachment.Attachment, err error) { + stmt, err := s.Runtime.Db.Preparex("SELECT id, refid, orgid, documentid, job, fileid, filename, data, extension, created, revised FROM attachment WHERE orgid=? and refid=?") + defer streamutil.Close(stmt) + + if err != nil { + err = errors.Wrap(err, "prepare select attachment") + return + } + + err = stmt.Get(&a, orgID, attachmentID) + if err != nil { + err = errors.Wrap(err, "execute select attachment") + return + } + + return +} + +// GetAttachments returns a slice containing the attachement records (excluding their data) for document docID, ordered by filename. +func (s Scope) GetAttachments(ctx domain.RequestContext, docID string) (a []attachment.Attachment, err error) { + err = s.Runtime.Db.Select(&a, "SELECT id, refid, orgid, documentid, job, fileid, filename, extension, created, revised FROM attachment WHERE orgid=? and documentid=? order by filename", ctx.OrgID, docID) + + if err != nil { + err = errors.Wrap(err, "execute select attachments") + return + } + + return +} + +// GetAttachmentsWithData returns a slice containing the attachement records (including their data) for document docID, ordered by filename. +func (s Scope) GetAttachmentsWithData(ctx domain.RequestContext, docID string) (a []attachment.Attachment, err error) { + err = s.Runtime.Db.Select(&a, "SELECT id, refid, orgid, documentid, job, fileid, filename, extension, data, created, revised FROM attachment WHERE orgid=? and documentid=? order by filename", ctx.OrgID, docID) + + if err != nil { + err = errors.Wrap(err, "execute select attachments with data") + return + } + + return +} + +// Delete deletes the id record from the database attachment table. +func (s Scope) Delete(ctx domain.RequestContext, id string) (rows int64, err error) { + b := mysql.BaseQuery{} + return b.DeleteConstrained(ctx.Transaction, "attachment", ctx.OrgID, id) +} diff --git a/domain/auth/endpoint.go b/domain/auth/endpoint.go index 2dd070c6..b6fd496f 100644 --- a/domain/auth/endpoint.go +++ b/domain/auth/endpoint.go @@ -11,13 +11,34 @@ package auth -/* -// Authenticate user based up HTTP Authorization header. -// An encrypted authentication token is issued with an expiry date. -func (h *Handler) Authenticate(w http.ResponseWriter, r *http.Request) { - method := "Authenticate" +import ( + "database/sql" + "errors" + "net/http" + "strings" - s := domain.StoreContext{Runtime: h.Runtime, Context: domain.GetRequestContext(r)} + "github.com/documize/community/core/env" + "github.com/documize/community/core/response" + "github.com/documize/community/core/secrets" + "github.com/documize/community/domain" + "github.com/documize/community/domain/organization" + "github.com/documize/community/domain/section/provider" + "github.com/documize/community/domain/user" + "github.com/documize/community/model/auth" + "github.com/documize/community/model/org" +) + +// Handler contains the runtime information such as logging and database. +type Handler struct { + Runtime *env.Runtime + Store *domain.Store +} + +// Login user based up HTTP Authorization header. +// An encrypted authentication token is issued with an expiry date. +func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { + method := "auth.Login" + ctx := domain.GetRequestContext(r) // check for http header authHeader := r.Header.Get("Authorization") @@ -46,23 +67,20 @@ func (h *Handler) Authenticate(w http.ResponseWriter, r *http.Request) { } dom := strings.TrimSpace(strings.ToLower(credentials[0])) - dom = organization.CheckDomain(s, dom) // TODO optimize by removing this once js allows empty domains + dom = h.Store.Organization.CheckDomain(ctx, dom) // TODO optimize by removing this once js allows empty domains email := strings.TrimSpace(strings.ToLower(credentials[1])) password := credentials[2] h.Runtime.Log.Info("logon attempt " + email + " @ " + dom) - u, err := user.GetByDomain(s, dom, email) - + u, err := h.Store.User.GetByDomain(ctx, dom, email) if err == sql.ErrNoRows { response.WriteUnauthorizedError(w) return } - if err != nil { response.WriteServerError(w, method, err) return } - if len(u.Reset) > 0 || len(u.Password) == 0 { response.WriteUnauthorizedError(w) return @@ -74,31 +92,29 @@ func (h *Handler) Authenticate(w http.ResponseWriter, r *http.Request) { return } - org, err := organization.GetOrganizationByDomain(s, dom) + org, err := h.Store.Organization.GetOrganizationByDomain(ctx, dom) if err != nil { response.WriteUnauthorizedError(w) return } // Attach user accounts and work out permissions - user.AttachUserAccounts(s, org.RefID, &u) - - // active check + user.AttachUserAccounts(ctx, *h.Store, org.RefID, &u) if len(u.Accounts) == 0 { response.WriteUnauthorizedError(w) return } - authModel := AuthenticationModel{} + authModel := auth.AuthenticationModel{} authModel.Token = GenerateJWT(h.Runtime, u.RefID, org.RefID, dom) authModel.User = u response.WriteJSON(w, authModel) } -// ValidateAuthToken finds and validates authentication token. -func (h *Handler) ValidateAuthToken(w http.ResponseWriter, r *http.Request) { +// ValidateToken finds and validates authentication token. +func (h *Handler) ValidateToken(w http.ResponseWriter, r *http.Request) { // TODO should this go after token validation? if s := r.URL.Query().Get("section"); s != "" { if err := provider.Callback(s, w, r); err != nil { @@ -109,40 +125,40 @@ func (h *Handler) ValidateAuthToken(w http.ResponseWriter, r *http.Request) { return } - s := domain.StoreContext{Runtime: h.Runtime, Context: domain.GetRequestContext(r)} - token := FindJWT(r) rc, _, tokenErr := DecodeJWT(h.Runtime, token) - var org = organization.Organization{} + var org = org.Organization{} var err = errors.New("") // We always grab the org record regardless of token status. // Why? If bad token we might be OK to alow anonymous access // depending upon the domain in question. if len(rc.OrgID) == 0 { - org, err = organization.GetOrganizationByDomain(s, organization.GetRequestSubdomain(s, r)) + dom := organization.GetRequestSubdomain(r) + org, err = h.Store.Organization.GetOrganizationByDomain(rc, dom) } else { - org, err = organization.GetOrganization(s, rc.OrgID) + org, err = h.Store.Organization.GetOrganization(rc, rc.OrgID) } rc.Subdomain = org.Domain // Inability to find org record spells the end of this request. if err != nil { - w.WriteHeader(http.StatusUnauthorized) + response.WriteUnauthorizedError(w) return } // If we have bad auth token and the domain does not allow anon access if !org.AllowAnonymousAccess && tokenErr != nil { + response.WriteUnauthorizedError(w) return } - dom := organization.GetSubdomainFromHost(s, r) - dom2 := organization.GetRequestSubdomain(s, r) + dom := organization.GetSubdomainFromHost(r) + dom2 := organization.GetRequestSubdomain(r) if org.Domain != dom && org.Domain != dom2 { - w.WriteHeader(http.StatusUnauthorized) + response.WriteUnauthorizedError(w) return } @@ -152,7 +168,7 @@ func (h *Handler) ValidateAuthToken(w http.ResponseWriter, r *http.Request) { // So you have a bad token if len(token) > 1 { if tokenErr != nil { - w.WriteHeader(http.StatusUnauthorized) + response.WriteUnauthorizedError(w) return } } else { @@ -170,18 +186,18 @@ func (h *Handler) ValidateAuthToken(w http.ResponseWriter, r *http.Request) { rc.Editor = false rc.Global = false rc.AppURL = r.Host - rc.Subdomain = organization.GetSubdomainFromHost(s, r) + rc.Subdomain = organization.GetSubdomainFromHost(r) rc.SSL = r.TLS != nil // Fetch user permissions for this org if !rc.Authenticated { - w.WriteHeader(http.StatusUnauthorized) + response.WriteUnauthorizedError(w) return } - u, err := user.GetSecuredUser(s, org.RefID, rc.UserID) + u, err := user.GetSecuredUser(rc, *h.Store, org.RefID, rc.UserID) if err != nil { - w.WriteHeader(http.StatusUnauthorized) + response.WriteUnauthorizedError(w) return } @@ -190,6 +206,4 @@ func (h *Handler) ValidateAuthToken(w http.ResponseWriter, r *http.Request) { rc.Global = u.Global response.WriteJSON(w, u) - return } -*/ diff --git a/domain/auth/keycloak.go b/domain/auth/keycloak.go new file mode 100644 index 00000000..a881c9bf --- /dev/null +++ b/domain/auth/keycloak.go @@ -0,0 +1,49 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package auth + +import ( + "encoding/json" + + "github.com/documize/community/core/env" + "github.com/documize/community/model/auth" +) + +// StripAuthSecrets removes sensitive data from auth provider configuration +func StripAuthSecrets(r *env.Runtime, provider, config string) string { + switch provider { + case "documize": + return config + break + case "keycloak": + c := auth.KeycloakConfig{} + err := json.Unmarshal([]byte(config), &c) + if err != nil { + r.Log.Error("StripAuthSecrets", err) + return config + } + c.AdminPassword = "" + c.AdminUser = "" + c.PublicKey = "" + + j, err := json.Marshal(c) + if err != nil { + r.Log.Error("StripAuthSecrets", err) + return config + } + + return string(j) + break + } + + return config +} diff --git a/domain/document/mysql/store.go b/domain/document/mysql/store.go index 6bdb6b02..ac4a289d 100644 --- a/domain/document/mysql/store.go +++ b/domain/document/mysql/store.go @@ -17,6 +17,7 @@ import ( "github.com/documize/community/core/env" "github.com/documize/community/core/streamutil" "github.com/documize/community/domain" + "github.com/documize/community/model/doc" "github.com/pkg/errors" ) @@ -25,6 +26,25 @@ type Scope struct { Runtime *env.Runtime } +// Get fetches the document record with the given id fromt the document table and audits that it has been got. +func (s Scope) Get(ctx domain.RequestContext, id string) (document doc.Document, err error) { + stmt, err := s.Runtime.Db.Preparex("SELECT id, refid, orgid, labelid, userid, job, location, title, excerpt, slug, tags, template, layout, created, revised FROM document WHERE orgid=? and refid=?") + defer streamutil.Close(stmt) + + if err != nil { + err = errors.Wrap(err, "prepare select document") + return + } + + err = stmt.Get(&document, ctx.OrgID, id) + if err != nil { + err = errors.Wrap(err, "execute select document") + return + } + + return +} + // MoveDocumentSpace changes the label for client's organization's documents which have space "id", to "move". func (s Scope) MoveDocumentSpace(ctx domain.RequestContext, id, move string) (err error) { stmt, err := ctx.Transaction.Preparex("UPDATE document SET labelid=? WHERE orgid=? AND labelid=?") @@ -43,3 +63,20 @@ func (s Scope) MoveDocumentSpace(ctx domain.RequestContext, id, move string) (er return } + +// PublicDocuments returns a slice of SitemapDocument records, holding documents in folders of type 1 (entity.TemplateTypePublic). +func (s Scope) PublicDocuments(ctx domain.RequestContext, orgID string) (documents []doc.SitemapDocument, err error) { + err = s.Runtime.Db.Select(&documents, + `SELECT d.refid as documentid, d.title as document, d.revised as revised, l.refid as folderid, l.label as folder +FROM document d LEFT JOIN label l ON l.refid=d.labelid +WHERE d.orgid=? +AND l.type=1 +AND d.template=0`, orgID) + + if err != nil { + err = errors.Wrap(err, fmt.Sprintf("execute GetPublicDocuments for org %s%s", orgID)) + return + } + + return +} diff --git a/domain/document/permission.go b/domain/document/permission.go new file mode 100644 index 00000000..5d8619bf --- /dev/null +++ b/domain/document/permission.go @@ -0,0 +1,116 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package document + +import ( + "database/sql" + + "github.com/documize/community/domain" +) + +// CanViewDocumentInFolder returns if the user has permission to view a document within the specified folder. +func CanViewDocumentInFolder(ctx domain.RequestContext, s domain.Store, labelID string) (hasPermission bool) { + roles, err := s.Space.GetUserRoles(ctx) + + if err == sql.ErrNoRows { + err = nil + } + if err != nil { + return false + } + + for _, role := range roles { + if role.LabelID == labelID && (role.CanView || role.CanEdit) { + return true + } + } + + return false +} + +// CanViewDocument returns if the clinet has permission to view a given document. +func CanViewDocument(ctx domain.RequestContext, s domain.Store, documentID string) (hasPermission bool) { + document, err := s.Document.Get(ctx, documentID) + + if err == sql.ErrNoRows { + err = nil + } + if err != nil { + return false + } + + roles, err := s.Space.GetUserRoles(ctx) + + if err == sql.ErrNoRows { + err = nil + } + if err != nil { + return false + } + + for _, role := range roles { + if role.LabelID == document.LabelID && (role.CanView || role.CanEdit) { + return true + } + } + + return false +} + +// CanChangeDocument returns if the clinet has permission to change a given document. +func CanChangeDocument(ctx domain.RequestContext, s domain.Store, documentID string) (hasPermission bool) { + document, err := s.Document.Get(ctx, documentID) + + if err == sql.ErrNoRows { + err = nil + } + if err != nil { + return false + } + + roles, err := s.Space.GetUserRoles(ctx) + + if err == sql.ErrNoRows { + err = nil + } + if err != nil { + return false + } + + for _, role := range roles { + if role.LabelID == document.LabelID && role.CanEdit { + return true + } + } + + return false +} + +// CanUploadDocument returns if the client has permission to upload documents to the given folderID. +func CanUploadDocument(ctx domain.RequestContext, s domain.Store, folderID string) (hasPermission bool) { + roles, err := s.Space.GetUserRoles(ctx) + + if err == sql.ErrNoRows { + err = nil + } + if err != nil { + return false + } + + for _, role := range roles { + if role.LabelID == folderID && role.CanEdit { + return true + } + } + + return false +} diff --git a/domain/link/endpoint.go b/domain/link/endpoint.go new file mode 100644 index 00000000..f843d8ff --- /dev/null +++ b/domain/link/endpoint.go @@ -0,0 +1,160 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package link + +import ( + "database/sql" + "net/http" + "net/url" + + "github.com/documize/community/core/env" + "github.com/documize/community/core/request" + "github.com/documize/community/core/response" + "github.com/documize/community/core/uniqueid" + "github.com/documize/community/domain" + "github.com/documize/community/domain/document" + "github.com/documize/community/model/attachment" + "github.com/documize/community/model/link" + "github.com/documize/community/model/page" +) + +// Handler contains the runtime information such as logging and database. +type Handler struct { + Runtime *env.Runtime + Store *domain.Store +} + +// GetLinkCandidates returns references to documents/sections/attachments. +func (h *Handler) GetLinkCandidates(w http.ResponseWriter, r *http.Request) { + method := "link.Candidates" + ctx := domain.GetRequestContext(r) + + folderID := request.Param(r, "folderID") + if len(folderID) == 0 { + response.WriteMissingDataError(w, method, "folderID") + return + } + + documentID := request.Param(r, "documentID") + if len(documentID) == 0 { + response.WriteMissingDataError(w, method, "documentID") + return + } + + pageID := request.Param(r, "pageID") + if len(pageID) == 0 { + response.WriteMissingDataError(w, method, "pageID") + return + } + + // permission check + if document.CanViewDocument(ctx, *h.Store, documentID) { + response.WriteForbiddenError(w) + return + } + + // We can link to a section within the same document so + // let's get all pages for the document and remove "us". + pages, err := h.Store.Page.GetPagesWithoutContent(ctx, documentID) + if err != nil && err != sql.ErrNoRows { + response.WriteServerError(w, method, err) + return + } + + if len(pages) == 0 { + pages = []page.Page{} + } + + pc := []link.Candidate{} + + for _, p := range pages { + if p.RefID != pageID { + c := link.Candidate{ + RefID: uniqueid.Generate(), + FolderID: folderID, + DocumentID: documentID, + TargetID: p.RefID, + LinkType: p.PageType, + Title: p.Title, + } + pc = append(pc, c) + } + } + + // We can link to attachment within the same document so + // let's get all attachments for the document. + files, err := h.Store.Attachment.GetAttachments(ctx, documentID) + if err != nil && err != sql.ErrNoRows { + response.WriteServerError(w, method, err) + return + } + + if len(files) == 0 { + files = []attachment.Attachment{} + } + + fc := []link.Candidate{} + + for _, f := range files { + c := link.Candidate{ + RefID: uniqueid.Generate(), + FolderID: folderID, + DocumentID: documentID, + TargetID: f.RefID, + LinkType: "file", + Title: f.Filename, + Context: f.Extension, + } + + fc = append(fc, c) + } + + var payload struct { + Pages []link.Candidate `json:"pages"` + Attachments []link.Candidate `json:"attachments"` + } + + payload.Pages = pc + payload.Attachments = fc + + response.WriteJSON(w, payload) +} + +// SearchLinkCandidates endpoint takes a list of keywords and returns a list of document references matching those keywords. +func (h *Handler) SearchLinkCandidates(w http.ResponseWriter, r *http.Request) { + method := "link.SearchLinkCandidates" + ctx := domain.GetRequestContext(r) + + keywords := request.Query(r, "keywords") + decoded, err := url.QueryUnescape(keywords) + if err != nil { + h.Runtime.Log.Error("decode query string", err) + } + + docs, pages, attachments, err := h.Store.Link.SearchCandidates(ctx, decoded) + if err != nil { + response.WriteServerError(w, method, err) + return + } + + var payload struct { + Documents []link.Candidate `json:"documents"` + Pages []link.Candidate `json:"pages"` + Attachments []link.Candidate `json:"attachments"` + } + + payload.Documents = docs + payload.Pages = pages + payload.Attachments = attachments + + response.WriteJSON(w, payload) +} diff --git a/domain/link/link.go b/domain/link/link.go index c8777e88..77713978 100644 --- a/domain/link/link.go +++ b/domain/link/link.go @@ -14,13 +14,13 @@ package link import ( "strings" - "github.com/documize/community/core/api/entity" + "github.com/documize/community/model/link" "golang.org/x/net/html" ) // GetContentLinks returns Documize generated links. // such links have an identifying attribute e.g. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package link + +import ( + "fmt" + "time" + + "github.com/documize/community/core/env" + "github.com/documize/community/core/streamutil" + "github.com/documize/community/core/uniqueid" + "github.com/documize/community/domain" + "github.com/documize/community/domain/store/mysql" + "github.com/documize/community/model/link" + "github.com/pkg/errors" +) + +// Scope provides data access to MySQL. +type Scope struct { + Runtime *env.Runtime +} + +// Add inserts wiki-link into the store. +// These links exist when content references another document or content. +func (s Scope) Add(ctx domain.RequestContext, l link.Link) (err error) { + l.Created = time.Now().UTC() + l.Revised = time.Now().UTC() + + stmt, err := ctx.Transaction.Preparex("INSERT INTO link (refid, orgid, folderid, userid, sourcedocumentid, sourcepageid, targetdocumentid, targetid, linktype, orphan, created, revised) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)") + defer streamutil.Close(stmt) + + if err != nil { + err = errors.Wrap(err, "prepare link insert") + return + } + + _, err = stmt.Exec(l.RefID, l.OrgID, l.FolderID, l.UserID, l.SourceDocumentID, l.SourcePageID, l.TargetDocumentID, l.TargetID, l.LinkType, l.Orphan, l.Created, l.Revised) + if err != nil { + err = errors.Wrap(err, "execute link insert") + return + } + + return +} + +// GetDocumentOutboundLinks returns outbound links for specified document. +func (s Scope) GetDocumentOutboundLinks(ctx domain.RequestContext, documentID string) (links []link.Link, err error) { + err = s.Runtime.Db.Select(&links, + `select l.refid, l.orgid, l.folderid, l.userid, l.sourcedocumentid, l.sourcepageid, l.targetdocumentid, l.targetid, l.linktype, l.orphan, l.created, l.revised + FROM link l + WHERE l.orgid=? AND l.sourcedocumentid=?`, + ctx.OrgID, + documentID) + + if err != nil { + return + } + + if len(links) == 0 { + links = []link.Link{} + } + + return +} + +// GetPageLinks returns outbound links for specified page in document. +func (s Scope) GetPageLinks(ctx domain.RequestContext, documentID, pageID string) (links []link.Link, err error) { + err = s.Runtime.Db.Select(&links, + `select l.refid, l.orgid, l.folderid, l.userid, l.sourcedocumentid, l.sourcepageid, l.targetdocumentid, l.targetid, l.linktype, l.orphan, l.created, l.revised + FROM link l + WHERE l.orgid=? AND l.sourcedocumentid=? AND l.sourcepageid=?`, + ctx.OrgID, + documentID, + pageID) + + if err != nil { + return + } + + if len(links) == 0 { + links = []link.Link{} + } + + return +} + +// MarkOrphanDocumentLink marks all link records referencing specified document. +func (s Scope) MarkOrphanDocumentLink(ctx domain.RequestContext, documentID string) (err error) { + revised := time.Now().UTC() + + stmt, err := ctx.Transaction.Preparex("UPDATE link SET orphan=1, revised=? WHERE linktype='document' AND orgid=? AND targetdocumentid=?") + defer streamutil.Close(stmt) + + if err != nil { + return + } + + _, err = stmt.Exec(revised, ctx.OrgID, documentID) + + return +} + +// MarkOrphanPageLink marks all link records referencing specified page. +func (s Scope) MarkOrphanPageLink(ctx domain.RequestContext, pageID string) (err error) { + revised := time.Now().UTC() + + stmt, err := ctx.Transaction.Preparex("UPDATE link SET orphan=1, revised=? WHERE linktype='section' AND orgid=? AND targetid=?") + defer streamutil.Close(stmt) + + if err != nil { + return + } + + _, err = stmt.Exec(revised, ctx.OrgID, pageID) + + return +} + +// MarkOrphanAttachmentLink marks all link records referencing specified attachment. +func (s Scope) MarkOrphanAttachmentLink(ctx domain.RequestContext, attachmentID string) (err error) { + revised := time.Now().UTC() + + stmt, err := ctx.Transaction.Preparex("UPDATE link SET orphan=1, revised=? WHERE linktype='file' AND orgid=? AND targetid=?") + defer streamutil.Close(stmt) + + if err != nil { + return + } + + _, err = stmt.Exec(revised, ctx.OrgID, attachmentID) + + return +} + +// DeleteSourcePageLinks removes saved links for given source. +func (s Scope) DeleteSourcePageLinks(ctx domain.RequestContext, pageID string) (rows int64, err error) { + b := mysql.BaseQuery{} + return b.DeleteWhere(ctx.Transaction, fmt.Sprintf("DELETE FROM link WHERE orgid=\"%s\" AND sourcepageid=\"%s\"", ctx.OrgID, pageID)) +} + +// DeleteSourceDocumentLinks removes saved links for given document. +func (s Scope) DeleteSourceDocumentLinks(ctx domain.RequestContext, documentID string) (rows int64, err error) { + b := mysql.BaseQuery{} + return b.DeleteWhere(ctx.Transaction, fmt.Sprintf("DELETE FROM link WHERE orgid=\"%s\" AND sourcedocumentid=\"%s\"", ctx.OrgID, documentID)) +} + +// DeleteLink removes saved link from the store. +func (s Scope) DeleteLink(ctx domain.RequestContext, id string) (rows int64, err error) { + b := mysql.BaseQuery{} + return b.DeleteConstrained(ctx.Transaction, "link", ctx.OrgID, id) +} + +// SearchCandidates returns matching documents, sections and attachments using keywords. +func (s Scope) SearchCandidates(ctx domain.RequestContext, keywords string) (docs []link.Candidate, + pages []link.Candidate, attachments []link.Candidate, err error) { + // find matching documents + temp := []link.Candidate{} + likeQuery := "title LIKE '%" + keywords + "%'" + + err = s.Runtime.Db.Select(&temp, + `SELECT refid as documentid, labelid as folderid,title from document WHERE orgid=? AND `+likeQuery+` AND labelid IN + (SELECT refid from label WHERE orgid=? AND type=2 AND userid=? + UNION ALL SELECT refid FROM label a where orgid=? AND type=1 AND refid IN (SELECT labelid from labelrole WHERE orgid=? AND userid='' AND (canedit=1 OR canview=1)) + UNION ALL SELECT refid FROM label a where orgid=? AND type=3 AND refid IN (SELECT labelid from labelrole WHERE orgid=? AND userid=? AND (canedit=1 OR canview=1))) + ORDER BY title`, + ctx.OrgID, + ctx.OrgID, + ctx.UserID, + ctx.OrgID, + ctx.OrgID, + ctx.OrgID, + ctx.OrgID, + ctx.UserID) + + if err != nil { + err = errors.Wrap(err, "execute search links 1") + return + } + + for _, r := range temp { + c := link.Candidate{ + RefID: uniqueid.Generate(), + FolderID: r.FolderID, + DocumentID: r.DocumentID, + TargetID: r.DocumentID, + LinkType: "document", + Title: r.Title, + Context: "", + } + + docs = append(docs, c) + } + + // find matching sections + likeQuery = "p.title LIKE '%" + keywords + "%'" + temp = []link.Candidate{} + + err = s.Runtime.Db.Select(&temp, + `SELECT p.refid as targetid, p.documentid as documentid, p.title as title, p.pagetype as linktype, d.title as context, d.labelid as folderid from page p + LEFT JOIN document d ON d.refid=p.documentid WHERE p.orgid=? AND `+likeQuery+` AND d.labelid IN + (SELECT refid from label WHERE orgid=? AND type=2 AND userid=? + UNION ALL SELECT refid FROM label a where orgid=? AND type=1 AND refid IN (SELECT labelid from labelrole WHERE orgid=? AND userid='' AND (canedit=1 OR canview=1)) + UNION ALL SELECT refid FROM label a where orgid=? AND type=3 AND refid IN (SELECT labelid from labelrole WHERE orgid=? AND userid=? AND (canedit=1 OR canview=1))) + ORDER BY p.title`, + ctx.OrgID, + ctx.OrgID, + ctx.UserID, + ctx.OrgID, + ctx.OrgID, + ctx.OrgID, + ctx.OrgID, + ctx.UserID) + + if err != nil { + err = errors.Wrap(err, "execute search links 2") + return + } + + for _, r := range temp { + c := link.Candidate{ + RefID: uniqueid.Generate(), + FolderID: r.FolderID, + DocumentID: r.DocumentID, + TargetID: r.TargetID, + LinkType: r.LinkType, + Title: r.Title, + Context: r.Context, + } + + pages = append(pages, c) + } + + // find matching attachments + likeQuery = "a.filename LIKE '%" + keywords + "%'" + temp = []link.Candidate{} + + err = s.Runtime.Db.Select(&temp, + `SELECT a.refid as targetid, a.documentid as documentid, a.filename as title, a.extension as context, d.labelid as folderid from attachment a + LEFT JOIN document d ON d.refid=a.documentid WHERE a.orgid=? AND `+likeQuery+` AND d.labelid IN + (SELECT refid from label WHERE orgid=? AND type=2 AND userid=? + UNION ALL SELECT refid FROM label a where orgid=? AND type=1 AND refid IN (SELECT labelid from labelrole WHERE orgid=? AND userid='' AND (canedit=1 OR canview=1)) + UNION ALL SELECT refid FROM label a where orgid=? AND type=3 AND refid IN (SELECT labelid from labelrole WHERE orgid=? AND userid=? AND (canedit=1 OR canview=1))) + ORDER BY a.filename`, + ctx.OrgID, + ctx.OrgID, + ctx.UserID, + ctx.OrgID, + ctx.OrgID, + ctx.OrgID, + ctx.OrgID, + ctx.UserID) + + if err != nil { + err = errors.Wrap(err, "execute search links 3") + return + } + + for _, r := range temp { + c := link.Candidate{ + RefID: uniqueid.Generate(), + FolderID: r.FolderID, + DocumentID: r.DocumentID, + TargetID: r.TargetID, + LinkType: "file", + Title: r.Title, + Context: r.Context, + } + + attachments = append(attachments, c) + } + + if len(docs) == 0 { + docs = []link.Candidate{} + } + if len(pages) == 0 { + pages = []link.Candidate{} + } + if len(attachments) == 0 { + attachments = []link.Candidate{} + } + + return +} diff --git a/domain/meta/endpoint.go b/domain/meta/endpoint.go new file mode 100644 index 00000000..58fa1668 --- /dev/null +++ b/domain/meta/endpoint.go @@ -0,0 +1,182 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package meta + +import ( + "bytes" + "fmt" + "net/http" + "text/template" + + "github.com/documize/community/core/env" + "github.com/documize/community/core/log" + "github.com/documize/community/core/response" + "github.com/documize/community/core/stringutil" + "github.com/documize/community/domain" + "github.com/documize/community/domain/auth" + "github.com/documize/community/domain/organization" + "github.com/documize/community/model/doc" + "github.com/documize/community/model/org" + "github.com/documize/community/model/space" +) + +// Handler contains the runtime information such as logging and database. +type Handler struct { + Runtime *env.Runtime + Store *domain.Store +} + +// Meta provides org meta data based upon request domain (e.g. acme.documize.com). +func (h *Handler) Meta(w http.ResponseWriter, r *http.Request) { + ctx := domain.GetRequestContext(r) + + data := org.SiteMeta{} + data.URL = organization.GetSubdomainFromHost(r) + + org, err := h.Store.Organization.GetOrganizationByDomain(ctx, data.URL) + if err != nil { + response.WriteForbiddenError(w) + return + } + + data.OrgID = org.RefID + data.Title = org.Title + data.Message = org.Message + data.AllowAnonymousAccess = org.AllowAnonymousAccess + data.AuthProvider = org.AuthProvider + data.AuthConfig = org.AuthConfig + data.Version = h.Runtime.Product.Version + data.Edition = h.Runtime.Product.License.Edition + data.Valid = h.Runtime.Product.License.Valid + data.ConversionEndpoint = org.ConversionEndpoint + + // Strip secrets + data.AuthConfig = auth.StripAuthSecrets(h.Runtime, org.AuthProvider, org.AuthConfig) + + response.WriteJSON(w, data) +} + +// RobotsTxt returns robots.txt depending on site configuration. +// Did we allow anonymouse access? +func (h *Handler) RobotsTxt(w http.ResponseWriter, r *http.Request) { + method := "GetRobots" + ctx := domain.GetRequestContext(r) + + dom := organization.GetSubdomainFromHost(r) + org, err := h.Store.Organization.GetOrganizationByDomain(ctx, dom) + + // default is to deny + robots := + `User-agent: * +Disallow: / +` + + if err != nil { + h.Runtime.Log.Error(fmt.Sprintf("%s failed to get Organization for domain %s", method, dom), err) + } + + // Anonymous access would mean we allow bots to crawl. + if org.AllowAnonymousAccess { + sitemap := ctx.GetAppURL("sitemap.xml") + robots = fmt.Sprintf( + `User-agent: * +Disallow: /settings/ +Disallow: /settings/* +Disallow: /profile/ +Disallow: /profile/* +Disallow: /auth/login/ +Disallow: /auth/login/ +Disallow: /auth/logout/ +Disallow: /auth/logout/* +Disallow: /auth/reset/* +Disallow: /auth/reset/* +Disallow: /auth/sso/ +Disallow: /auth/sso/* +Disallow: /share +Disallow: /share/* +Sitemap: %s`, sitemap) + } + + response.WriteBytes(w, []byte(robots)) +} + +// Sitemap returns URLs that can be indexed. +// We only include public folders and documents (e.g. can be seen by everyone). +func (h *Handler) Sitemap(w http.ResponseWriter, r *http.Request) { + method := "meta.Sitemap" + ctx := domain.GetRequestContext(r) + + dom := organization.GetSubdomainFromHost(r) + org, err := h.Store.Organization.GetOrganizationByDomain(ctx, dom) + + if err != nil { + h.Runtime.Log.Error(fmt.Sprintf("%s failed to get Organization for domain %s", method, dom), err) + } + + sitemap := + ` + + {{range .}} + {{ .URL }} + {{ .Date }} + {{end}} +` + + var items []sitemapItem + + // Anonymous access means we announce folders/documents shared with 'Everyone'. + if org.AllowAnonymousAccess { + // Grab shared folders + folders, err := h.Store.Space.PublicSpaces(ctx, org.RefID) + if err != nil { + folders = []space.Space{} + h.Runtime.Log.Error(fmt.Sprintf("%s failed to get folders for domain %s", method, dom), err) + } + + for _, folder := range folders { + var item sitemapItem + item.URL = ctx.GetAppURL(fmt.Sprintf("s/%s/%s", folder.RefID, stringutil.MakeSlug(folder.Name))) + item.Date = folder.Revised.Format("2006-01-02T15:04:05.999999-07:00") + items = append(items, item) + } + + // Grab documents from shared folders + var documents []doc.SitemapDocument + documents, err = h.Store.Document.PublicDocuments(ctx, org.RefID) + if err != nil { + documents = []doc.SitemapDocument{} + h.Runtime.Log.Error(fmt.Sprintf("%s failed to get documents for domain %s", method, dom), err) + } + + for _, document := range documents { + var item sitemapItem + item.URL = ctx.GetAppURL(fmt.Sprintf("s/%s/%s/d/%s/%s", + document.FolderID, stringutil.MakeSlug(document.Folder), document.DocumentID, stringutil.MakeSlug(document.Document))) + item.Date = document.Revised.Format("2006-01-02T15:04:05.999999-07:00") + items = append(items, item) + } + + } + + buffer := new(bytes.Buffer) + t := template.Must(template.New("tmp").Parse(sitemap)) + log.IfErr(t.Execute(buffer, &items)) + + response.WriteBytes(w, buffer.Bytes()) +} + +// sitemapItem provides a means to teleport somewhere else for free. +// What did you think it did? +type sitemapItem struct { + URL string + Date string +} diff --git a/domain/page/mysql/store.go b/domain/page/mysql/store.go new file mode 100644 index 00000000..d1e02d0f --- /dev/null +++ b/domain/page/mysql/store.go @@ -0,0 +1,39 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package mysql + +import ( + "fmt" + + "github.com/documize/community/core/env" + "github.com/documize/community/domain" + "github.com/documize/community/model/page" + "github.com/pkg/errors" +) + +// Scope provides data access to MySQL. +type Scope struct { + Runtime *env.Runtime +} + +// GetPagesWithoutContent returns a slice containing all the page records for a given documentID, in presentation sequence, +// but without the body field (which holds the HTML content). +func (s Scope) GetPagesWithoutContent(ctx domain.RequestContext, documentID string) (pages []page.Page, err error) { + err = s.Runtime.Db.Select(&pages, "SELECT id, refid, orgid, documentid, userid, contenttype, pagetype, sequence, level, title, revisions, blockid, created, revised FROM page WHERE orgid=? AND documentid=? ORDER BY sequence", ctx.OrgID, documentID) + + if err != nil { + err = errors.Wrap(err, fmt.Sprintf("Unable to execute select pages for org %s and document %s", ctx.OrgID, documentID)) + return + } + + return +} diff --git a/domain/setting/endpoint.go b/domain/setting/endpoint.go new file mode 100644 index 00000000..dc1b9180 --- /dev/null +++ b/domain/setting/endpoint.go @@ -0,0 +1,248 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +// Package setting manages both global and user level settings +package setting + +import ( + "encoding/json" + "encoding/xml" + "io/ioutil" + "net/http" + + "github.com/documize/community/core/env" + "github.com/documize/community/core/event" + "github.com/documize/community/core/response" + "github.com/documize/community/domain" + "github.com/documize/community/model/audit" +) + +// Handler contains the runtime information such as logging and database. +type Handler struct { + Runtime *env.Runtime + Store *domain.Store +} + +// SMTP returns installation-wide SMTP settings +func (h *Handler) SMTP(w http.ResponseWriter, r *http.Request) { + method := "setting.SMTP" + ctx := domain.GetRequestContext(r) + + if !ctx.Global { + response.WriteForbiddenError(w) + return + } + + config := h.Store.Setting.Get(ctx, "SMTP", "") + + var y map[string]interface{} + json.Unmarshal([]byte(config), &y) + + j, err := json.Marshal(y) + if err != nil { + response.WriteBadRequestError(w, method, err.Error()) + return + } + + response.WriteBytes(w, j) +} + +// SetSMTP persists global SMTP configuration. +func (h *Handler) SetSMTP(w http.ResponseWriter, r *http.Request) { + method := "setting.SetSMTP" + ctx := domain.GetRequestContext(r) + + if !ctx.Global { + response.WriteForbiddenError(w) + return + } + + defer r.Body.Close() + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + response.WriteBadRequestError(w, method, err.Error()) + return + } + + var config string + config = string(body) + + ctx.Transaction, err = h.Runtime.Db.Beginx() + if err != nil { + response.WriteServerError(w, method, err) + return + } + + h.Store.Setting.Set(ctx, "SMTP", config) + + h.Store.Audit.Record(ctx, audit.EventTypeSystemSMTP) + + response.WriteEmpty(w) +} + +// License returns product license +func (h *Handler) License(w http.ResponseWriter, r *http.Request) { + ctx := domain.GetRequestContext(r) + + if !ctx.Global { + response.WriteForbiddenError(w) + return + } + + config := h.Store.Setting.Get(ctx, "EDITION-LICENSE", "") + if len(config) == 0 { + config = "{}" + } + + x := &licenseXML{Key: "", Signature: ""} + lj := licenseJSON{} + + err := json.Unmarshal([]byte(config), &lj) + if err == nil { + x.Key = lj.Key + x.Signature = lj.Signature + } else { + h.Runtime.Log.Error("failed to JSON unmarshal EDITION-LICENSE", err) + } + + output, err := xml.Marshal(x) + if err != nil { + h.Runtime.Log.Error("failed to XML marshal EDITION-LICENSE", err) + } + + response.WriteBytes(w, output) +} + +// SetLicense persists product license +func (h *Handler) SetLicense(w http.ResponseWriter, r *http.Request) { + method := "setting.SetLicense" + ctx := domain.GetRequestContext(r) + + if !ctx.Global { + response.WriteForbiddenError(w) + return + } + + defer r.Body.Close() + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + response.WriteBadRequestError(w, method, err.Error()) + return + } + + var config string + config = string(body) + lj := licenseJSON{} + x := licenseXML{Key: "", Signature: ""} + + err = xml.Unmarshal([]byte(config), &x) + if err == nil { + lj.Key = x.Key + lj.Signature = x.Signature + } else { + h.Runtime.Log.Error("failed to XML unmarshal EDITION-LICENSE", err) + } + + j, err := json.Marshal(lj) + js := "{}" + if err == nil { + js = string(j) + } + + h.Store.Setting.Set(ctx, "EDITION-LICENSE", js) + + event.Handler().Publish(string(event.TypeSystemLicenseChange)) + + ctx.Transaction, err = h.Runtime.Db.Beginx() + if err != nil { + response.WriteServerError(w, method, err) + return + } + + h.Store.Audit.Record(ctx, audit.EventTypeSystemLicense) + ctx.Transaction.Commit() + + response.WriteEmpty(w) +} + +// AuthConfig returns installation-wide auth configuration +func (h *Handler) AuthConfig(w http.ResponseWriter, r *http.Request) { + ctx := domain.GetRequestContext(r) + + if !ctx.Global { + response.WriteForbiddenError(w) + return + } + + org, err := h.Store.Organization.GetOrganization(ctx, ctx.OrgID) + if err != nil { + response.WriteForbiddenError(w) + return + } + + response.WriteJSON(w, org.AuthConfig) +} + +// SetAuthConfig persists installation-wide authentication configuration +func (h *Handler) SetAuthConfig(w http.ResponseWriter, r *http.Request) { + method := "SaveAuthConfig" + ctx := domain.GetRequestContext(r) + + if !ctx.Global { + response.WriteForbiddenError(w) + return + } + + defer r.Body.Close() + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + response.WriteBadRequestError(w, method, err.Error()) + return + } + + var data authData + err = json.Unmarshal(body, &data) + if err != nil { + response.WriteBadRequestError(w, method, err.Error()) + return + } + + org, err := h.Store.Organization.GetOrganization(ctx, ctx.OrgID) + if err != nil { + response.WriteServerError(w, method, err) + return + } + + org.AuthProvider = data.AuthProvider + org.AuthConfig = data.AuthConfig + + ctx.Transaction, err = h.Runtime.Db.Beginx() + if err != nil { + response.WriteServerError(w, method, err) + return + } + + err = h.Store.Organization.UpdateAuthConfig(ctx, org) + if err != nil { + ctx.Transaction.Rollback() + response.WriteServerError(w, method, err) + return + } + + h.Store.Audit.Record(ctx, audit.EventTypeSystemAuth) + + ctx.Transaction.Commit() + + response.WriteEmpty(w) +} diff --git a/domain/setting/model.go b/domain/setting/model.go new file mode 100644 index 00000000..844c9758 --- /dev/null +++ b/domain/setting/model.go @@ -0,0 +1,38 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +// Package setting manages both global and user level settings +package setting + +import "encoding/xml" + +type licenseXML struct { + XMLName xml.Name `xml:"DocumizeLicense"` + Key string + Signature string +} + +type licenseJSON struct { + Key string `json:"key"` + Signature string `json:"signature"` +} + +type authData struct { + AuthProvider string `json:"authProvider"` + AuthConfig string `json:"authConfig"` +} + +/* + + some key + some signature + +*/ diff --git a/domain/setting/mysql/setting.go b/domain/setting/mysql/setting.go new file mode 100644 index 00000000..2c9bb641 --- /dev/null +++ b/domain/setting/mysql/setting.go @@ -0,0 +1,135 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package mysql + +import ( + "bytes" + "fmt" + + "github.com/documize/community/core/env" + "github.com/documize/community/core/streamutil" + "github.com/documize/community/domain" + "github.com/pkg/errors" +) + +// Scope provides data access to MySQL. +type Scope struct { + Runtime *env.Runtime +} + +// Get fetches a configuration JSON element from the config table. +func (s Scope) Get(ctx domain.RequestContext, area, path string) (value string) { + if path != "" { + path = "." + path + } + sql := "SELECT JSON_EXTRACT(`config`,'$" + path + "') FROM `config` WHERE `key` = '" + area + "';" + + stmt, err := s.Runtime.Db.Preparex(sql) + defer streamutil.Close(stmt) + + if err != nil { + s.Runtime.Log.Error(fmt.Sprintf("setting.Get %s %s", area, path), err) + return "" + } + + var item = make([]uint8, 0) + + err = stmt.Get(&item) + if err != nil { + s.Runtime.Log.Error(fmt.Sprintf("setting.Get %s %s", area, path), err) + return "" + } + + if len(item) > 1 { + q := []byte(`"`) + value = string(bytes.TrimPrefix(bytes.TrimSuffix(item, q), q)) + } + + return value +} + +// Set writes a configuration JSON element to the config table. +func (s Scope) Set(ctx domain.RequestContext, area, json string) error { + if area == "" { + return errors.New("no area") + } + + sql := "INSERT INTO `config` (`key`,`config`) " + + "VALUES ('" + area + "','" + json + + "') ON DUPLICATE KEY UPDATE `config`='" + json + "';" + + stmt, err := s.Runtime.Db.Preparex(sql) + defer streamutil.Close(stmt) + + if err != nil { + err = errors.Wrap(err, "failed to save global config value") + return err + } + + _, err = stmt.Exec() + return err +} + +// GetUser fetches a configuration JSON element from the userconfig table for a given orgid/userid combination. +// Errors return the empty string. A blank path returns the whole JSON object, as JSON. +func (s Scope) GetUser(ctx domain.RequestContext, orgID, userID, area, path string) (value string) { + if path != "" { + path = "." + path + } + + sql := "SELECT JSON_EXTRACT(`config`,'$" + path + "') FROM `userconfig` WHERE `key` = '" + area + + "' AND `orgid` = '" + orgID + "' AND `userid` = '" + userID + "';" + + stmt, err := s.Runtime.Db.Preparex(sql) + defer streamutil.Close(stmt) + + if err != nil { + return "" + } + + var item = make([]uint8, 0) + + err = stmt.Get(&item) + if err != nil { + s.Runtime.Log.Error(fmt.Sprintf("setting.GetUser for user %s %s %s", userID, area, path), err) + return "" + } + + if len(item) > 1 { + q := []byte(`"`) + value = string(bytes.TrimPrefix(bytes.TrimSuffix(item, q), q)) + } + + return value +} + +// SetUser writes a configuration JSON element to the userconfig table for the current user. +func (s Scope) SetUser(ctx domain.RequestContext, orgID, userID, area, json string) error { + if area == "" { + return errors.New("no area") + } + + sql := "INSERT INTO `userconfig` (`orgid`,`userid`,`key`,`config`) " + + "VALUES ('" + orgID + "','" + userID + "','" + area + "','" + json + + "') ON DUPLICATE KEY UPDATE `config`='" + json + "';" + + stmt, err := s.Runtime.Db.Preparex(sql) + defer streamutil.Close(stmt) + + if err != nil { + return err + } + + _, err = stmt.Exec() + + return err +} diff --git a/domain/storer.go b/domain/storer.go index 83d6eeaa..6d6f3b72 100644 --- a/domain/storer.go +++ b/domain/storer.go @@ -14,8 +14,12 @@ package domain import ( "github.com/documize/community/model/account" + "github.com/documize/community/model/attachment" "github.com/documize/community/model/audit" + "github.com/documize/community/model/doc" + "github.com/documize/community/model/link" "github.com/documize/community/model/org" + "github.com/documize/community/model/page" "github.com/documize/community/model/pin" "github.com/documize/community/model/space" "github.com/documize/community/model/user" @@ -30,6 +34,10 @@ type Store struct { Pin PinStorer Audit AuditStorer Document DocumentStorer + Setting SettingStorer + Attachment AttachmentStorer + Link LinkStorer + Page PageStorer } // SpaceStorer defines required methods for space management @@ -112,7 +120,43 @@ type AuditStorer interface { // DocumentStorer defines required methods for document handling type DocumentStorer interface { + Get(ctx RequestContext, id string) (document doc.Document, err error) MoveDocumentSpace(ctx RequestContext, id, move string) (err error) + PublicDocuments(ctx RequestContext, orgID string) (documents []doc.SitemapDocument, err error) } -// https://github.com/golang-sql/sqlexp/blob/c2488a8be21d20d31abf0d05c2735efd2d09afe4/quoter.go#L46 +// SettingStorer defines required methods for persisting global and user level settings +type SettingStorer interface { + Get(ctx RequestContext, area, path string) string + Set(ctx RequestContext, area, value string) error + GetUser(ctx RequestContext, orgID, userID, area, path string) string + SetUser(ctx RequestContext, orgID, userID, area, json string) error +} + +// AttachmentStorer defines required methods for persisting document attachments +type AttachmentStorer interface { + Add(ctx RequestContext, a attachment.Attachment) (err error) + GetAttachment(ctx RequestContext, orgID, attachmentID string) (a attachment.Attachment, err error) + GetAttachments(ctx RequestContext, docID string) (a []attachment.Attachment, err error) + GetAttachmentsWithData(ctx RequestContext, docID string) (a []attachment.Attachment, err error) + Delete(ctx RequestContext, id string) (rows int64, err error) +} + +// LinkStorer defines required methods for persisting content links +type LinkStorer interface { + Add(ctx RequestContext, l link.Link) (err error) + SearchCandidates(ctx RequestContext, keywords string) (docs []link.Candidate, pages []link.Candidate, attachments []link.Candidate, err error) + GetDocumentOutboundLinks(ctx RequestContext, documentID string) (links []link.Link, err error) + GetPageLinks(ctx RequestContext, documentID, pageID string) (links []link.Link, err error) + MarkOrphanDocumentLink(ctx RequestContext, documentID string) (err error) + MarkOrphanPageLink(ctx RequestContext, pageID string) (err error) + MarkOrphanAttachmentLink(ctx RequestContext, attachmentID string) (err error) + DeleteSourcePageLinks(ctx RequestContext, pageID string) (rows int64, err error) + DeleteSourceDocumentLinks(ctx RequestContext, documentID string) (rows int64, err error) + DeleteLink(ctx RequestContext, id string) (rows int64, err error) +} + +// PageStorer defines required methods for persisting document pages +type PageStorer interface { + GetPagesWithoutContent(ctx RequestContext, documentID string) (pages []page.Page, err error) +} diff --git a/domain/template/mysql/store.go b/domain/template/mysql/store.go new file mode 100644 index 00000000..e69de29b diff --git a/domain/user/endpoint.go b/domain/user/endpoint.go index fa30dbf3..847e2e67 100644 --- a/domain/user/endpoint.go +++ b/domain/user/endpoint.go @@ -12,27 +12,48 @@ package user import ( + "database/sql" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + + "strconv" + + "github.com/documize/community/core/api/mail" "github.com/documize/community/core/env" + "github.com/documize/community/core/event" + "github.com/documize/community/core/request" + "github.com/documize/community/core/response" + "github.com/documize/community/core/secrets" + "github.com/documize/community/core/streamutil" + "github.com/documize/community/core/stringutil" + "github.com/documize/community/core/uniqueid" "github.com/documize/community/domain" + "github.com/documize/community/model/account" + "github.com/documize/community/model/audit" + "github.com/documize/community/model/space" + "github.com/documize/community/model/user" ) // Handler contains the runtime information such as logging and database. type Handler struct { Runtime *env.Runtime - Store domain.Store + Store *domain.Store } -/* -// AddUser is the endpoint that enables an administrator to add a new user for their orgaisation. -func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) { - method := "user.AddUser" +// Add is the endpoint that enables an administrator to add a new user for their orgaisation. +func (h *Handler) Add(w http.ResponseWriter, r *http.Request) { + method := "user.Add" ctx := domain.GetRequestContext(r) if !h.Runtime.Product.License.IsValid() { response.WriteBadLicense(w) } - if !s.Context.Administrator { + if !ctx.Administrator { response.WriteForbiddenError(w) return } @@ -44,7 +65,7 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) { return } - userModel := model.User{} + userModel := user.User{} err = json.Unmarshal(body, &userModel) if err != nil { response.WriteBadRequestError(w, method, err.Error()) @@ -82,7 +103,7 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) { addAccount := true var userID string - userDupe, err := h.Store.User.GetByEmail(s, userModel.Email) + userDupe, err := h.Store.User.GetByEmail(ctx, userModel.Email) if err != nil && err != sql.ErrNoRows { response.WriteServerError(w, method, err) return @@ -95,7 +116,7 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) { h.Runtime.Log.Info("Dupe user found, will not add " + userModel.Email) } - s.Context.Transaction, err = request.Db.Beginx() + ctx.Transaction, err = h.Runtime.Db.Beginx() if err != nil { response.WriteServerError(w, method, err) return @@ -105,19 +126,19 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) { userID = uniqueid.Generate() userModel.RefID = userID - err = h.Store.User.Add(s, userModel) + err = h.Store.User.Add(ctx, userModel) if err != nil { - s.Context.Transaction.Rollback() + ctx.Transaction.Rollback() response.WriteServerError(w, method, err) return } h.Runtime.Log.Info("Adding user") } else { - AttachUserAccounts(s, s.Context.OrgID, &userDupe) + AttachUserAccounts(ctx, *h.Store, ctx.OrgID, &userDupe) for _, a := range userDupe.Accounts { - if a.OrgID == s.Context.OrgID { + if a.OrgID == ctx.OrgID { addAccount = false h.Runtime.Log.Info("Dupe account found, will not add") break @@ -127,17 +148,17 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) { // set up user account for the org if addAccount { - var a model.Account + var a account.Account a.RefID = uniqueid.Generate() a.UserID = userID - a.OrgID = s.Context.OrgID + a.OrgID = ctx.OrgID a.Editor = true a.Admin = false a.Active = true - err = account.Add(s, a) + err = h.Store.Account.Add(ctx, a) if err != nil { - s.Context.Transaction.Rollback() + ctx.Transaction.Rollback() response.WriteServerError(w, method, err) return } @@ -145,15 +166,15 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) { if addUser { event.Handler().Publish(string(event.TypeAddUser)) - eventing.Record(s, eventing.EventTypeUserAdd) + h.Store.Audit.Record(ctx, audit.EventTypeUserAdd) } if addAccount { event.Handler().Publish(string(event.TypeAddAccount)) - eventing.Record(s, eventing.EventTypeAccountAdd) + h.Store.Audit.Record(ctx, audit.EventTypeAccountAdd) } - s.Context.Transaction.Commit() + ctx.Transaction.Commit() // If we did not add user or give them access (account) then we error back if !addUser && !addAccount { @@ -162,7 +183,7 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) { } // Invite new user - inviter, err := h.Store.User.Get(s, s.Context.UserID) + inviter, err := h.Store.User.Get(ctx, ctx.UserID) if err != nil { response.WriteServerError(w, method, err) return @@ -172,16 +193,16 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) { if addUser && addAccount { size := len(requestedPassword) - auth := fmt.Sprintf("%s:%s:%s", s.Context.AppURL, userModel.Email, requestedPassword[:size]) + auth := fmt.Sprintf("%s:%s:%s", ctx.AppURL, userModel.Email, requestedPassword[:size]) encrypted := secrets.EncodeBase64([]byte(auth)) - url := fmt.Sprintf("%s/%s", s.Context.GetAppURL("auth/sso"), url.QueryEscape(string(encrypted))) + url := fmt.Sprintf("%s/%s", ctx.GetAppURL("auth/sso"), url.QueryEscape(string(encrypted))) go mail.InviteNewUser(userModel.Email, inviter.Fullname(), url, userModel.Email, requestedPassword) - h.Runtime.Log.Info(fmt.Sprintf("%s invited by %s on %s", userModel.Email, inviter.Email, s.Context.AppURL)) + h.Runtime.Log.Info(fmt.Sprintf("%s invited by %s on %s", userModel.Email, inviter.Email, ctx.AppURL)) } else { - go mail.InviteExistingUser(userModel.Email, inviter.Fullname(), s.Context.GetAppURL("")) + go mail.InviteExistingUser(userModel.Email, inviter.Fullname(), ctx.GetAppURL("")) h.Runtime.Log.Info(fmt.Sprintf("%s is giving access to an existing user %s", inviter.Email, userModel.Email)) } @@ -189,33 +210,32 @@ func (h *Handler) AddUser(w http.ResponseWriter, r *http.Request) { response.WriteJSON(w, userModel) } -/* // GetOrganizationUsers is the endpoint that allows administrators to view the users in their organisation. func (h *Handler) GetOrganizationUsers(w http.ResponseWriter, r *http.Request) { - method := "pin.GetUserPins" - s := domain.NewContext(h.Runtime, r) + method := "user.GetOrganizationUsers" + ctx := domain.GetRequestContext(r) - if !s.Context.Editor && !s.Context.Administrator { + if !ctx.Administrator { response.WriteForbiddenError(w) return } - active, err := strconv.ParseBool(request.Query("active")) + active, err := strconv.ParseBool(request.Query(r, "active")) if err != nil { active = false } - u := []User{} + u := []user.User{} if active { - u, err = GetActiveUsersForOrganization(s) + u, err = h.Store.User.GetActiveUsersForOrganization(ctx) if err != nil && err != sql.ErrNoRows { response.WriteServerError(w, method, err) return } } else { - u, err = GetUsersForOrganization(s) + u, err = h.Store.User.GetUsersForOrganization(ctx) if err != nil && err != sql.ErrNoRows { response.WriteServerError(w, method, err) return @@ -223,11 +243,11 @@ func (h *Handler) GetOrganizationUsers(w http.ResponseWriter, r *http.Request) { } if len(u) == 0 { - u = []User{} + u = []user.User{} } for i := range u { - AttachUserAccounts(s, s.Context.OrgID, &u[i]) + AttachUserAccounts(ctx, *h.Store, ctx.OrgID, &u[i]) } response.WriteJSON(w, u) @@ -236,19 +256,19 @@ func (h *Handler) GetOrganizationUsers(w http.ResponseWriter, r *http.Request) { // GetSpaceUsers returns every user within a given space func (h *Handler) GetSpaceUsers(w http.ResponseWriter, r *http.Request) { method := "user.GetSpaceUsers" - s := domain.NewContext(h.Runtime, r) + ctx := domain.GetRequestContext(r) - var u []User + var u []user.User var err error - folderID := request.Param("folderID") + folderID := request.Param(r, "folderID") if len(folderID) == 0 { response.WriteMissingDataError(w, method, "folderID") return } // check to see space type as it determines user selection criteria - folder, err := space.Get(s, folderID) + folder, err := h.Store.Space.Get(ctx, folderID) if err != nil && err != sql.ErrNoRows { h.Runtime.Log.Error("cannot get space", err) response.WriteJSON(w, u) @@ -256,22 +276,22 @@ func (h *Handler) GetSpaceUsers(w http.ResponseWriter, r *http.Request) { } switch folder.Type { - case entity.FolderTypePublic: - u, err = GetActiveUsersForOrganization(s) + case space.ScopePublic: + u, err = h.Store.User.GetActiveUsersForOrganization(ctx) break - case entity.FolderTypePrivate: + case space.ScopePrivate: // just me - var me User - user, err = Get(s, s.Context.UserID) + var me user.User + me, err = h.Store.User.Get(ctx, ctx.UserID) u = append(u, me) break - case entity.FolderTypeRestricted: - u, err = GetSpaceUsers(s, folderID) + case space.ScopeRestricted: + u, err = h.Store.User.GetSpaceUsers(ctx, folderID) break } if len(u) == 0 { - u = []User + u = []user.User{} } if err != nil && err != sql.ErrNoRows { @@ -283,25 +303,25 @@ func (h *Handler) GetSpaceUsers(w http.ResponseWriter, r *http.Request) { response.WriteJSON(w, u) } -// GetUser returns user specified by ID -func (h *Handler) GetUser(w http.ResponseWriter, r *http.Request) { - method := "user.GetUser" - s := domain.NewContext(h.Runtime, r) +// Get returns user specified by ID +func (h *Handler) Get(w http.ResponseWriter, r *http.Request) { + method := "user.Get" + ctx := domain.GetRequestContext(r) - userID := request.Param("userID") + userID := request.Param(r, "userID") if len(userID) == 0 { response.WriteMissingDataError(w, method, "userId") return } - if userID != s.Context.UserID { + if userID != ctx.UserID { response.WriteBadRequestError(w, method, "userId mismatch") return } - u, err := GetSecuredUser(s, s.Context.OrgID, userID) + u, err := GetSecuredUser(ctx, *h.Store, ctx.OrgID, userID) if err == sql.ErrNoRows { - response.WriteNotFoundError(s, method, s.Context.UserID) + response.WriteNotFoundError(w, method, ctx.UserID) return } if err != nil { @@ -309,66 +329,63 @@ func (h *Handler) GetUser(w http.ResponseWriter, r *http.Request) { return } - response.WriteJSON(u) + response.WriteJSON(w, u) +} -// DeleteUser is the endpoint to delete a user specified by userID, the caller must be an Administrator. -func (h *Handler) DeleteUser(w http.ResponseWriter, r *http.Request) { - method := "user.DeleteUser" - s := domain.NewContext(h.Runtime, r) +// Delete is the endpoint to delete a user specified by userID, the caller must be an Administrator. +func (h *Handler) Delete(w http.ResponseWriter, r *http.Request) { + method := "user.Delete" + ctx := domain.GetRequestContext(r) - if !s.Context.Administrator { - response.WriteForbiddenError(w) - return - } - - userID := response.Params("userID") + userID := request.Param(r, "userID") if len(userID) == 0 { - response.WriteMissingDataError(w, method, "userID") + response.WriteMissingDataError(w, method, "userId") return } - if userID == s.Context.UserID { + if userID == ctx.UserID { response.WriteBadRequestError(w, method, "cannot delete self") return } var err error - s.Context.Transaction, err = h.Runtime.Db.Beginx() + ctx.Transaction, err = h.Runtime.Db.Beginx() if err != nil { response.WriteServerError(w, method, err) return } - err = DeactiveUser(s, userID) + err = h.Store.User.DeactiveUser(ctx, userID) if err != nil { - s.Context.Transaction.Rollback() + ctx.Transaction.Rollback() response.WriteServerError(w, method, err) return } - err = space.ChangeLabelOwner(s, userID, s.Context.UserID) + err = h.Store.Space.ChangeOwner(ctx, userID, ctx.UserID) if err != nil { - s.Context.Transaction.Rollback() + ctx.Transaction.Rollback() response.WriteServerError(w, method, err) return } - eventing.Record(s, eventing.EventTypeUserDelete) + h.Store.Audit.Record(ctx, audit.EventTypeUserDelete) + event.Handler().Publish(string(event.TypeRemoveUser)) - s.Context.Transaction.Commit() + ctx.Transaction.Commit() - response.WriteEmpty() + response.WriteEmpty(w) } -// UpdateUser is the endpoint to update user information for the given userID. +// Update is the endpoint to update user information for the given userID. // Note that unless they have admin privildges, a user can only update their own information. // Also, only admins can update user roles in organisations. -func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) { - method := "user.DeleteUser" - s := domain.NewContext(h.Runtime, r) +func (h *Handler) Update(w http.ResponseWriter, r *http.Request) { + method := "user.Update" + ctx := domain.GetRequestContext(r) - userID := request.Param("userID") + userID := request.Param(r, "userID") if len(userID) == 0 { response.WriteBadRequestError(w, method, "user id must be numeric") return @@ -377,11 +394,11 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) { defer streamutil.Close(r.Body) body, err := ioutil.ReadAll(r.Body) if err != nil { - response.WritePayloadError(w, method, err) + response.WriteBadRequestError(w, method, err.Error()) return } - u := User{} + u := user.User{} err = json.Unmarshal(body, &u) if err != nil { response.WriteBadRequestError(w, method, err.Error()) @@ -389,7 +406,7 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) { } // can only update your own account unless you are an admin - if s.Context.UserID != userID && !s.Context.Administrator { + if ctx.UserID != userID && !ctx.Administrator { response.WriteForbiddenError(w) return } @@ -400,7 +417,7 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) { return } - s.Context.Transaction, err = h.Runtime.Db.Beginx() + ctx.Transaction, err = h.Runtime.Db.Beginx() if err != nil { response.WriteServerError(w, method, err) return @@ -409,9 +426,9 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) { u.RefID = userID u.Initials = stringutil.MakeInitials(u.Firstname, u.Lastname) - err = UpdateUser(s, u) + err = h.Store.User.UpdateUser(ctx, u) if err != nil { - s.Context.Transaction.Rollback() + ctx.Transaction.Rollback() response.WriteServerError(w, method, err) return } @@ -419,9 +436,9 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) { // Now we update user roles for this organization. // That means we have to first find their account record // for this organization. - a, err := account.GetUserAccount(s, userID) + a, err := h.Store.Account.GetUserAccount(ctx, userID) if err != nil { - s.Context.Transaction.Rollback() + ctx.Transaction.Rollback() response.WriteServerError(w, method, err) return } @@ -430,26 +447,26 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request) { a.Admin = u.Admin a.Active = u.Active - err = account.UpdateAccount(s, account) + err = h.Store.Account.UpdateAccount(ctx, a) if err != nil { - s.Context.Transaction.Rollback() + ctx.Transaction.Rollback() response.WriteServerError(w, method, err) return } - eventing.Record(s, eventing.EventTypeUserUpdate) + h.Store.Audit.Record(ctx, audit.EventTypeUserUpdate) - s.Context.Transaction.Commit() + ctx.Transaction.Commit() - response.WriteJSON(u) + response.WriteEmpty(w) } -// ChangeUserPassword accepts password change from within the app. -func (h *Handler) ChangeUserPassword(w http.ResponseWriter, r *http.Request) { - method := "user.ChangeUserPassword" - s := domain.NewContext(h.Runtime, r) +// ChangePassword accepts password change from within the app. +func (h *Handler) ChangePassword(w http.ResponseWriter, r *http.Request) { + method := "user.ChangePassword" + ctx := domain.GetRequestContext(r) - userID := response.Param("userID") + userID := request.Param(r, "userID") if len(userID) == 0 { response.WriteMissingDataError(w, method, "user id") return @@ -464,18 +481,18 @@ func (h *Handler) ChangeUserPassword(w http.ResponseWriter, r *http.Request) { newPassword := string(body) // can only update your own account unless you are an admin - if userID != s.Context.UserID && !s.Context.Administrator { + if userID != ctx.UserID || !ctx.Administrator { response.WriteForbiddenError(w) return } - s.Context.Transaction, err = h.Runtime.Db.Beginx() + ctx.Transaction, err = h.Runtime.Db.Beginx() if err != nil { response.WriteServerError(w, method, err) return } - u, err := Get(s, userID) + u, err := h.Store.User.Get(ctx, userID) if err != nil { response.WriteServerError(w, method, err) return @@ -483,28 +500,29 @@ func (h *Handler) ChangeUserPassword(w http.ResponseWriter, r *http.Request) { u.Salt = secrets.GenerateSalt() - err = UpdateUserPassword(s, userID, user.Salt, secrets.GeneratePassword(newPassword, user.Salt)) + err = h.Store.User.UpdateUserPassword(ctx, userID, u.Salt, secrets.GeneratePassword(newPassword, u.Salt)) if err != nil { response.WriteServerError(w, method, err) return } - s.Context.Transaction.Rollback() + ctx.Transaction.Rollback() + response.WriteEmpty(w) } -// GetUserFolderPermissions returns folder permission for authenticated user. -func (h *Handler) GetUserFolderPermissions(w http.ResponseWriter, r *http.Request) { - method := "user.ChangeUserPassword" - s := domain.NewContext(h.Runtime, r) +// UserSpacePermissions returns folder permission for authenticated user. +func (h *Handler) UserSpacePermissions(w http.ResponseWriter, r *http.Request) { + method := "user.UserSpacePermissions" + ctx := domain.GetRequestContext(r) - userID := request.Param("userID") - if userID != p.Context.UserID { + userID := request.Param(r, "userID") + if userID != ctx.UserID { response.WriteForbiddenError(w) return } - roles, err := space.GetUserLabelRoles(s, userID) + roles, err := h.Store.Space.GetUserRoles(ctx) if err == sql.ErrNoRows { err = nil roles = []space.Role{} @@ -517,12 +535,12 @@ func (h *Handler) GetUserFolderPermissions(w http.ResponseWriter, r *http.Reques response.WriteJSON(w, roles) } -// ForgotUserPassword initiates the change password procedure. +// ForgotPassword initiates the change password procedure. // Generates a reset token and sends email to the user. // User has to click link in email and then provide a new password. -func (h *Handler) ForgotUserPassword(w http.ResponseWriter, r *http.Request) { - method := "user.ForgotUserPassword" - s := domain.NewContext(h.Runtime, r) +func (h *Handler) ForgotPassword(w http.ResponseWriter, r *http.Request) { + method := "user.ForgotPassword" + ctx := domain.GetRequestContext(r) defer streamutil.Close(r.Body) body, err := ioutil.ReadAll(r.Body) @@ -531,14 +549,14 @@ func (h *Handler) ForgotUserPassword(w http.ResponseWriter, r *http.Request) { return } - u := new(User) + u := new(user.User) err = json.Unmarshal(body, &u) if err != nil { response.WriteBadRequestError(w, method, "JSON body") return } - s.Context.Transaction, err = request.Db.Beginx() + ctx.Transaction, err = h.Runtime.Db.Beginx() if err != nil { response.WriteServerError(w, method, err) return @@ -546,33 +564,33 @@ func (h *Handler) ForgotUserPassword(w http.ResponseWriter, r *http.Request) { token := secrets.GenerateSalt() - err = ForgotUserPassword(s, u.Email, token) + err = h.Store.User.ForgotUserPassword(ctx, u.Email, token) if err != nil && err != sql.ErrNoRows { - s.Context.Transaction.Rollback() + ctx.Transaction.Rollback() response.WriteServerError(w, method, err) return } if err == sql.ErrNoRows { response.WriteEmpty(w) - h.Runtime.Log.Info(fmt.Errorf("User %s not found for password reset process", u.Email)) + h.Runtime.Log.Info(fmt.Sprintf("User %s not found for password reset process", u.Email)) return } - s.Context.Transaction.Commit() + ctx.Transaction.Commit() - appURL := s.Context.GetAppURL(fmt.Sprintf("auth/reset/%s", token)) + appURL := ctx.GetAppURL(fmt.Sprintf("auth/reset/%s", token)) go mail.PasswordReset(u.Email, appURL) response.WriteEmpty(w) } -// ResetUserPassword stores the newly chosen password for the user. -func (h *Handler) ResetUserPassword(w http.ResponseWriter, r *http.Request) { +// ResetPassword stores the newly chosen password for the user. +func (h *Handler) ResetPassword(w http.ResponseWriter, r *http.Request) { method := "user.ForgotUserPassword" - s := domain.NewContext(h.Runtime, r) + ctx := domain.GetRequestContext(r) - token := request.Param("token") + token := request.Param(r, "token") if len(token) == 0 { response.WriteMissingDataError(w, method, "missing token") return @@ -586,31 +604,30 @@ func (h *Handler) ResetUserPassword(w http.ResponseWriter, r *http.Request) { } newPassword := string(body) - s.Context.Transaction, err = h.Runtime.Db.Beginx() + ctx.Transaction, err = h.Runtime.Db.Beginx() if err != nil { response.WriteServerError(w, method, err) return } - u, err := GetByToken(token) + u, err := h.Store.User.GetByToken(ctx, token) if err != nil { response.WriteServerError(w, method, err) return } - user.Salt = secrets.GenerateSalt() + u.Salt = secrets.GenerateSalt() - err = UpdateUserPassword(s, u.RefID, u.Salt, secrets.GeneratePassword(newPassword, u.Salt)) + err = h.Store.User.UpdateUserPassword(ctx, u.RefID, u.Salt, secrets.GeneratePassword(newPassword, u.Salt)) if err != nil { - s.Context.Transaction.Rollback() + ctx.Transaction.Rollback() response.WriteServerError(w, method, err) return } - eventing.Record(s, eventing.EventTypeUserPasswordReset) + h.Store.Audit.Record(ctx, audit.EventTypeUserPasswordReset) - s.Context.Transaction.Commit() + ctx.Transaction.Commit() response.WriteEmpty(w) } -*/ diff --git a/edition/boot/store.go b/edition/boot/store.go index 60c209b2..c5b71983 100644 --- a/edition/boot/store.go +++ b/edition/boot/store.go @@ -16,12 +16,16 @@ import ( "github.com/documize/community/core/env" "github.com/documize/community/domain" account "github.com/documize/community/domain/account/mysql" + attachment "github.com/documize/community/domain/attachment/mysql" audit "github.com/documize/community/domain/audit/mysql" + doc "github.com/documize/community/domain/document/mysql" + link "github.com/documize/community/domain/link/mysql" org "github.com/documize/community/domain/organization/mysql" + page "github.com/documize/community/domain/page/mysql" pin "github.com/documize/community/domain/pin/mysql" + setting "github.com/documize/community/domain/setting/mysql" space "github.com/documize/community/domain/space/mysql" user "github.com/documize/community/domain/user/mysql" - doc "github.com/documize/community/domain/document/mysql" ) // AttachStore selects database persistence layer @@ -33,4 +37,10 @@ func AttachStore(r *env.Runtime, s *domain.Store) { s.Pin = pin.Scope{Runtime: r} s.Audit = audit.Scope{Runtime: r} s.Document = doc.Scope{Runtime: r} + s.Setting = setting.Scope{Runtime: r} + s.Attachment = attachment.Scope{Runtime: r} + s.Link = link.Scope{Runtime: r} + s.Page = page.Scope{Runtime: r} } + +// https://github.com/golang-sql/sqlexp/blob/c2488a8be21d20d31abf0d05c2735efd2d09afe4/quoter.go#L46 diff --git a/model/activity/activity.go b/model/activity/activity.go new file mode 100644 index 00000000..c989b598 --- /dev/null +++ b/model/activity/activity.go @@ -0,0 +1,73 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package activity + +import "time" + +// UserActivity represents an activity undertaken by a user. +type UserActivity struct { + ID uint64 `json:"-"` + OrgID string `json:"orgId"` + UserID string `json:"userId"` + LabelID string `json:"folderId"` + SourceID string `json:"sourceId"` + SourceName string `json:"sourceName"` // e.g. Document or Space name + SourceType SourceType `json:"sourceType"` + ActivityType Type `json:"activityType"` + Created time.Time `json:"created"` +} + +// SourceType details where the activity occured. +type SourceType int + +// Type determines type of user activity +type Type int + +const ( + // SourceTypeSpace indicates activity against a space. + SourceTypeSpace SourceType = 1 + + // SourceTypeDocument indicates activity against a document. + SourceTypeDocument SourceType = 2 +) + +const ( + // TypeCreated records user document creation + TypeCreated Type = 1 + + // TypeRead states user has read document + TypeRead Type = 2 + + // TypeEdited states user has editing document + TypeEdited Type = 3 + + // TypeDeleted records user deleting space/document + TypeDeleted Type = 4 + + // TypeArchived records user archiving space/document + TypeArchived Type = 5 + + // TypeApproved records user approval of document + TypeApproved Type = 6 + + // TypeReverted records user content roll-back to previous version + TypeReverted Type = 7 + + // TypePublishedTemplate records user creating new document template + TypePublishedTemplate Type = 8 + + // TypePublishedBlock records user creating reusable content block + TypePublishedBlock Type = 9 + + // TypeFeedback records user providing document feedback + TypeFeedback Type = 10 +) diff --git a/model/attachment/attachment.go b/model/attachment/attachment.go new file mode 100644 index 00000000..a5f7b967 --- /dev/null +++ b/model/attachment/attachment.go @@ -0,0 +1,26 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package attachment + +import "github.com/documize/community/model" + +// Attachment represents an attachment to a document. +type Attachment struct { + model.BaseEntity + OrgID string `json:"orgId"` + DocumentID string `json:"documentId"` + Job string `json:"job"` + FileID string `json:"fileId"` + Filename string `json:"filename"` + Data []byte `json:"-"` + Extension string `json:"extension"` +} diff --git a/domain/auth/model.go b/model/auth/auth.go similarity index 82% rename from domain/auth/model.go rename to model/auth/auth.go index 4e0f02ef..018b9c8f 100644 --- a/domain/auth/model.go +++ b/model/auth/auth.go @@ -11,15 +11,10 @@ package auth -/* -// Handler contains the runtime information such as logging and database. -type Handler struct { - Runtime *env.Runtime -} +import "github.com/documize/community/model/user" // AuthenticationModel details authentication token and user details. type AuthenticationModel struct { Token string `json:"token"` User user.User `json:"user"` } -*/ diff --git a/model/auth/keycloak.go b/model/auth/keycloak.go new file mode 100644 index 00000000..a075a3aa --- /dev/null +++ b/model/auth/keycloak.go @@ -0,0 +1,52 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package auth + +// KeycloakAuthRequest data received via Keycloak client library +type KeycloakAuthRequest struct { + Domain string `json:"domain"` + Token string `json:"token"` + RemoteID string `json:"remoteId"` + Email string `json:"email"` + Username string `json:"username"` + Firstname string `json:"firstname"` + Lastname string `json:"lastname"` + Enabled bool `json:"enabled"` +} + +// KeycloakConfig server configuration +type KeycloakConfig struct { + URL string `json:"url"` + Realm string `json:"realm"` + ClientID string `json:"clientId"` + PublicKey string `json:"publicKey"` + AdminUser string `json:"adminUser"` + AdminPassword string `json:"adminPassword"` + Group string `json:"group"` + DisableLogout bool `json:"disableLogout"` + DefaultPermissionAddSpace bool `json:"defaultPermissionAddSpace"` +} + +// KeycloakAPIAuth is returned when authenticating with Keycloak REST API. +type KeycloakAPIAuth struct { + AccessToken string `json:"access_token"` +} + +// KeycloakUser details user record returned by Keycloak +type KeycloakUser struct { + ID string `json:"id"` + Username string `json:"username"` + Email string `json:"email"` + Firstname string `json:"firstName"` + Lastname string `json:"lastName"` + Enabled bool `json:"enabled"` +} diff --git a/model/doc/doc.go b/model/doc/doc.go new file mode 100644 index 00000000..e80d5f11 --- /dev/null +++ b/model/doc/doc.go @@ -0,0 +1,82 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package doc + +import ( + "strings" + "time" + + "github.com/documize/community/model" +) + +// Document represents the purpose of Documize. +type Document struct { + model.BaseEntity + OrgID string `json:"orgId"` + LabelID string `json:"folderId"` + UserID string `json:"userId"` + Job string `json:"job"` + Location string `json:"location"` + Title string `json:"name"` + Excerpt string `json:"excerpt"` + Slug string `json:"-"` + Tags string `json:"tags"` + Template bool `json:"template"` + Layout string `json:"layout"` +} + +// SetDefaults ensures on blanks and cleans. +func (d *Document) SetDefaults() { + d.Title = strings.TrimSpace(d.Title) + + if len(d.Title) == 0 { + d.Title = "Document" + } +} + +// DocumentMeta details who viewed the document. +type DocumentMeta struct { + Viewers []DocumentMetaViewer `json:"viewers"` + Editors []DocumentMetaEditor `json:"editors"` +} + +// DocumentMetaViewer contains the "view" metatdata content. +type DocumentMetaViewer struct { + UserID string `json:"userId"` + Created time.Time `json:"created"` + Firstname string `json:"firstname"` + Lastname string `json:"lastname"` +} + +// DocumentMetaEditor contains the "edit" metatdata content. +type DocumentMetaEditor struct { + PageID string `json:"pageId"` + UserID string `json:"userId"` + Action string `json:"action"` + Created time.Time `json:"created"` + Firstname string `json:"firstname"` + Lastname string `json:"lastname"` +} + +// UploadModel details the job ID of an uploaded document. +type UploadModel struct { + JobID string `json:"jobId"` +} + +// SitemapDocument details a document that can be exposed via Sitemap. +type SitemapDocument struct { + DocumentID string + Document string + FolderID string + Folder string + Revised time.Time +} diff --git a/model/link/link.go b/model/link/link.go new file mode 100644 index 00000000..1af326d2 --- /dev/null +++ b/model/link/link.go @@ -0,0 +1,39 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package link + +import "github.com/documize/community/model" + +// Link defines a reference between a section and another document/section/attachment. +type Link struct { + model.BaseEntity + OrgID string `json:"orgId"` + FolderID string `json:"folderId"` + UserID string `json:"userId"` + LinkType string `json:"linkType"` + SourceDocumentID string `json:"sourceDocumentId"` + SourcePageID string `json:"sourcePageId"` + TargetDocumentID string `json:"targetDocumentId"` + TargetID string `json:"targetId"` + Orphan bool `json:"orphan"` +} + +// Candidate defines a potential link to a document/section/attachment. +type Candidate struct { + RefID string `json:"id"` + LinkType string `json:"linkType"` + FolderID string `json:"folderId"` + DocumentID string `json:"documentId"` + TargetID string `json:"targetId"` + Title string `json:"title"` // what we label the link + Context string `json:"context"` // additional context (e.g. excerpt, parent, file extension) +} diff --git a/model/org/meta.go b/model/org/meta.go new file mode 100644 index 00000000..c5815dcb --- /dev/null +++ b/model/org/meta.go @@ -0,0 +1,38 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package org + +import "time" + +// SitemapDocument details a document that can be exposed via Sitemap. +type SitemapDocument struct { + DocumentID string + Document string + FolderID string + Folder string + Revised time.Time +} + +// SiteMeta holds information associated with an Organization. +type SiteMeta struct { + OrgID string `json:"orgId"` + Title string `json:"title"` + Message string `json:"message"` + URL string `json:"url"` + AllowAnonymousAccess bool `json:"allowAnonymousAccess"` + AuthProvider string `json:"authProvider"` + AuthConfig string `json:"authConfig"` + Version string `json:"version"` + Edition string `json:"edition"` + Valid bool `json:"valid"` + ConversionEndpoint string `json:"conversionEndpoint"` +} diff --git a/model/page/page.go b/model/page/page.go new file mode 100644 index 00000000..3663827b --- /dev/null +++ b/model/page/page.go @@ -0,0 +1,116 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package page + +import ( + "strings" + "time" + + "github.com/documize/community/model" +) + +// Page represents a section within a document. +type Page struct { + model.BaseEntity + OrgID string `json:"orgId"` + DocumentID string `json:"documentId"` + UserID string `json:"userId"` + ContentType string `json:"contentType"` + PageType string `json:"pageType"` + BlockID string `json:"blockId"` + Level uint64 `json:"level"` + Sequence float64 `json:"sequence"` + Title string `json:"title"` + Body string `json:"body"` + Revisions uint64 `json:"revisions"` +} + +// SetDefaults ensures no blank values. +func (p *Page) SetDefaults() { + if len(p.ContentType) == 0 { + p.ContentType = "wysiwyg" + } + + p.Title = strings.TrimSpace(p.Title) +} + +// IsSectionType tells us that page is "words" +func (p *Page) IsSectionType() bool { + return p.PageType == "section" +} + +// IsTabType tells us that page is "SaaS data embed" +func (p *Page) IsTabType() bool { + return p.PageType == "tab" +} + +// Meta holds raw page data that is used to +// render the actual page data. +type Meta struct { + ID uint64 `json:"id"` + Created time.Time `json:"created"` + Revised time.Time `json:"revised"` + OrgID string `json:"orgId"` + UserID string `json:"userId"` + DocumentID string `json:"documentId"` + PageID string `json:"pageId"` + RawBody string `json:"rawBody"` // a blob of data + Config string `json:"config"` // JSON based custom config for this type + ExternalSource bool `json:"externalSource"` // true indicates data sourced externally +} + +// SetDefaults ensures no blank values. +func (p *Meta) SetDefaults() { + if len(p.Config) == 0 { + p.Config = "{}" + } +} + +// Revision holds the previous version of a Page. +type Revision struct { + model.BaseEntity + OrgID string `json:"orgId"` + DocumentID string `json:"documentId"` + PageID string `json:"pageId"` + OwnerID string `json:"ownerId"` + UserID string `json:"userId"` + ContentType string `json:"contentType"` + PageType string `json:"pageType"` + Title string `json:"title"` + Body string `json:"body"` + RawBody string `json:"rawBody"` + Config string `json:"config"` + Email string `json:"email"` + Firstname string `json:"firstname"` + Lastname string `json:"lastname"` + Initials string `json:"initials"` + Revisions int `json:"revisions"` +} + +// Block represents a section that has been published as a reusable content block. +type Block struct { + model.BaseEntity + OrgID string `json:"orgId"` + LabelID string `json:"folderId"` + UserID string `json:"userId"` + ContentType string `json:"contentType"` + PageType string `json:"pageType"` + Title string `json:"title"` + Body string `json:"body"` + Excerpt string `json:"excerpt"` + RawBody string `json:"rawBody"` // a blob of data + Config string `json:"config"` // JSON based custom config for this type + ExternalSource bool `json:"externalSource"` // true indicates data sourced externally + Used uint64 `json:"used"` + Firstname string `json:"firstname"` + Lastname string `json:"lastname"` +} diff --git a/model/search/search.go b/model/search/search.go new file mode 100644 index 00000000..d52dd4a5 --- /dev/null +++ b/model/search/search.go @@ -0,0 +1,43 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package search + +import "time" + +// Search holds raw search results. +type Search struct { + ID string `json:"id"` + Created time.Time `json:"created"` + Revised time.Time `json:"revised"` + OrgID string + DocumentID string + Level uint64 + Sequence float64 + DocumentTitle string + Slug string + PageTitle string + Body string +} + +// DocumentSearch represents 'presentable' search results. +type DocumentSearch struct { + ID string `json:"id"` + DocumentID string `json:"documentId"` + DocumentTitle string `json:"documentTitle"` + DocumentSlug string `json:"documentSlug"` + DocumentExcerpt string `json:"documentExcerpt"` + Tags string `json:"documentTags"` + PageTitle string `json:"pageTitle"` + LabelID string `json:"folderId"` + LabelName string `json:"folderName"` + FolderSlug string `json:"folderSlug"` +} diff --git a/model/template/template.go b/model/template/template.go new file mode 100644 index 00000000..9cb4ac25 --- /dev/null +++ b/model/template/template.go @@ -0,0 +1,54 @@ +// Copyright 2016 Documize Inc. . All rights reserved. +// +// This software (Documize Community Edition) is licensed under +// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html +// +// You can operate outside the AGPL restrictions by purchasing +// Documize Enterprise Edition and obtaining a commercial license +// by contacting . +// +// https://documize.com + +package template + +import "time" + +// Template is used to create a new document. +// Template can consist of content, attachments and +// have associated meta data indentifying author, version +// contact details and more. +type Template struct { + ID string `json:"id"` + Title string `json:"title"` + Description string `json:"description"` + Author string `json:"author"` + Type Type `json:"type"` + Dated time.Time `json:"dated"` +} + +// Type determines who can see a template. +type Type int + +const ( + // TypePublic means anyone can see the template. + TypePublic Type = 1 + // TypePrivate means only the owner can see the template. + TypePrivate Type = 2 + // TypeRestricted means selected users can see the template. + TypeRestricted Type = 3 +) + +// IsPublic means anyone can see the template. +func (t *Template) IsPublic() bool { + return t.Type == TypePublic +} + +// IsPrivate means only the owner can see the template. +func (t *Template) IsPrivate() bool { + return t.Type == TypePrivate +} + +// IsRestricted means selected users can see the template. +func (t *Template) IsRestricted() bool { + return t.Type == TypeRestricted +} diff --git a/server/routing/entries.go b/server/routing/entries.go index 0da266be..8de00f6a 100644 --- a/server/routing/entries.go +++ b/server/routing/entries.go @@ -14,31 +14,48 @@ package routing import ( "net/http" - "github.com/documize/community/core/api" "github.com/documize/community/core/api/endpoint" "github.com/documize/community/core/env" "github.com/documize/community/domain" + "github.com/documize/community/domain/attachment" + "github.com/documize/community/domain/auth" + "github.com/documize/community/domain/link" + "github.com/documize/community/domain/meta" "github.com/documize/community/domain/organization" "github.com/documize/community/domain/pin" + "github.com/documize/community/domain/setting" "github.com/documize/community/domain/space" + "github.com/documize/community/domain/user" "github.com/documize/community/server/web" ) // RegisterEndpoints register routes for serving API endpoints func RegisterEndpoints(rt *env.Runtime, s *domain.Store) { + // We pass server/application level contextual requirements into HTTP handlers + // DO NOT pass in per request context (that is done by auth middleware per request) + pin := pin.Handler{Runtime: rt, Store: s} + auth := auth.Handler{Runtime: rt, Store: s} + meta := meta.Handler{Runtime: rt, Store: s} + user := user.Handler{Runtime: rt, Store: s} + link := link.Handler{Runtime: rt, Store: s} + space := space.Handler{Runtime: rt, Store: s} + setting := setting.Handler{Runtime: rt, Store: s} + attachment := attachment.Handler{Runtime: rt, Store: s} + organization := organization.Handler{Runtime: rt, Store: s} + //************************************************** // Non-secure routes //************************************************** - Add(rt, RoutePrefixPublic, "meta", []string{"GET", "OPTIONS"}, nil, endpoint.GetMeta) + Add(rt, RoutePrefixPublic, "meta", []string{"GET", "OPTIONS"}, nil, meta.Meta) Add(rt, RoutePrefixPublic, "authenticate/keycloak", []string{"POST", "OPTIONS"}, nil, endpoint.AuthenticateKeycloak) - Add(rt, RoutePrefixPublic, "authenticate", []string{"POST", "OPTIONS"}, nil, endpoint.Authenticate) - Add(rt, RoutePrefixPublic, "validate", []string{"GET", "OPTIONS"}, nil, endpoint.ValidateAuthToken) - Add(rt, RoutePrefixPublic, "forgot", []string{"POST", "OPTIONS"}, nil, endpoint.ForgotUserPassword) - Add(rt, RoutePrefixPublic, "reset/{token}", []string{"POST", "OPTIONS"}, nil, endpoint.ResetUserPassword) - Add(rt, RoutePrefixPublic, "share/{folderID}", []string{"POST", "OPTIONS"}, nil, endpoint.AcceptSharedFolder) - Add(rt, RoutePrefixPublic, "attachments/{orgID}/{attachmentID}", []string{"GET", "OPTIONS"}, nil, endpoint.AttachmentDownload) + Add(rt, RoutePrefixPublic, "authenticate", []string{"POST", "OPTIONS"}, nil, auth.Login) + Add(rt, RoutePrefixPublic, "validate", []string{"GET", "OPTIONS"}, nil, auth.ValidateToken) + Add(rt, RoutePrefixPublic, "forgot", []string{"POST", "OPTIONS"}, nil, user.ForgotPassword) + Add(rt, RoutePrefixPublic, "reset/{token}", []string{"POST", "OPTIONS"}, nil, user.ResetPassword) + Add(rt, RoutePrefixPublic, "share/{folderID}", []string{"POST", "OPTIONS"}, nil, space.AcceptInvitation) + Add(rt, RoutePrefixPublic, "attachments/{orgID}/{attachmentID}", []string{"GET", "OPTIONS"}, nil, attachment.Download) Add(rt, RoutePrefixPublic, "version", []string{"GET", "OPTIONS"}, nil, func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(api.Runtime.Product.Version)) + w.Write([]byte(rt.Product.Version)) }) //************************************************** @@ -48,7 +65,6 @@ func RegisterEndpoints(rt *env.Runtime, s *domain.Store) { // Import & Convert Document Add(rt, RoutePrefixPrivate, "import/folder/{folderID}", []string{"POST", "OPTIONS"}, nil, endpoint.UploadConvertDocument) - // Document Add(rt, RoutePrefixPrivate, "documents/{documentID}/export", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentAsDocx) Add(rt, RoutePrefixPrivate, "documents", []string{"GET", "OPTIONS"}, []string{"filter", "tag"}, endpoint.GetDocumentsByTag) Add(rt, RoutePrefixPrivate, "documents", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentsByFolder) @@ -57,7 +73,6 @@ func RegisterEndpoints(rt *env.Runtime, s *domain.Store) { Add(rt, RoutePrefixPrivate, "documents/{documentID}", []string{"DELETE", "OPTIONS"}, nil, endpoint.DeleteDocument) Add(rt, RoutePrefixPrivate, "documents/{documentID}/activity", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentActivity) - // Document Page Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/level", []string{"POST", "OPTIONS"}, nil, endpoint.ChangeDocumentPageLevel) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/sequence", []string{"POST", "OPTIONS"}, nil, endpoint.ChangeDocumentPageSequence) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/batch", []string{"POST", "OPTIONS"}, nil, endpoint.GetDocumentPagesBatch) @@ -72,19 +87,15 @@ func RegisterEndpoints(rt *env.Runtime, s *domain.Store) { Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages", []string{"DELETE", "OPTIONS"}, nil, endpoint.DeleteDocumentPages) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/{pageID}", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentPage) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages", []string{"POST", "OPTIONS"}, nil, endpoint.AddDocumentPage) - Add(rt, RoutePrefixPrivate, "documents/{documentID}/attachments", []string{"GET", "OPTIONS"}, nil, endpoint.GetAttachments) - Add(rt, RoutePrefixPrivate, "documents/{documentID}/attachments/{attachmentID}", []string{"DELETE", "OPTIONS"}, nil, endpoint.DeleteAttachment) - Add(rt, RoutePrefixPrivate, "documents/{documentID}/attachments", []string{"POST", "OPTIONS"}, nil, endpoint.AddAttachments) + Add(rt, RoutePrefixPrivate, "documents/{documentID}/attachments", []string{"GET", "OPTIONS"}, nil, attachment.Get) + Add(rt, RoutePrefixPrivate, "documents/{documentID}/attachments/{attachmentID}", []string{"DELETE", "OPTIONS"}, nil, attachment.Delete) + Add(rt, RoutePrefixPrivate, "documents/{documentID}/attachments", []string{"POST", "OPTIONS"}, nil, attachment.Add) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/{pageID}/meta", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentPageMeta) Add(rt, RoutePrefixPrivate, "documents/{documentID}/pages/{pageID}/copy/{targetID}", []string{"POST", "OPTIONS"}, nil, endpoint.CopyPage) - // Organization - organization := organization.Handler{Runtime: rt, Store: s} Add(rt, RoutePrefixPrivate, "organizations/{orgID}", []string{"GET", "OPTIONS"}, nil, organization.Get) Add(rt, RoutePrefixPrivate, "organizations/{orgID}", []string{"PUT", "OPTIONS"}, nil, organization.Update) - // Space - space := space.Handler{Runtime: rt, Store: s} Add(rt, RoutePrefixPrivate, "folders/{folderID}", []string{"DELETE", "OPTIONS"}, nil, space.Delete) Add(rt, RoutePrefixPrivate, "folders/{folderID}/move/{moveToId}", []string{"DELETE", "OPTIONS"}, nil, space.Remove) Add(rt, RoutePrefixPrivate, "folders/{folderID}/permissions", []string{"PUT", "OPTIONS"}, nil, space.SetPermissions) @@ -96,28 +107,24 @@ func RegisterEndpoints(rt *env.Runtime, s *domain.Store) { Add(rt, RoutePrefixPrivate, "folders/{folderID}", []string{"GET", "OPTIONS"}, nil, space.Get) Add(rt, RoutePrefixPrivate, "folders/{folderID}", []string{"PUT", "OPTIONS"}, nil, space.Update) - // Users - Add(rt, RoutePrefixPrivate, "users/{userID}/password", []string{"POST", "OPTIONS"}, nil, endpoint.ChangeUserPassword) - Add(rt, RoutePrefixPrivate, "users/{userID}/permissions", []string{"GET", "OPTIONS"}, nil, endpoint.GetUserFolderPermissions) - Add(rt, RoutePrefixPrivate, "users", []string{"POST", "OPTIONS"}, nil, endpoint.AddUser) - Add(rt, RoutePrefixPrivate, "users/folder/{folderID}", []string{"GET", "OPTIONS"}, nil, endpoint.GetFolderUsers) - Add(rt, RoutePrefixPrivate, "users", []string{"GET", "OPTIONS"}, nil, endpoint.GetOrganizationUsers) - Add(rt, RoutePrefixPrivate, "users/{userID}", []string{"GET", "OPTIONS"}, nil, endpoint.GetUser) - Add(rt, RoutePrefixPrivate, "users/{userID}", []string{"PUT", "OPTIONS"}, nil, endpoint.UpdateUser) - Add(rt, RoutePrefixPrivate, "users/{userID}", []string{"DELETE", "OPTIONS"}, nil, endpoint.DeleteUser) + Add(rt, RoutePrefixPrivate, "users/{userID}/password", []string{"POST", "OPTIONS"}, nil, user.ChangePassword) + Add(rt, RoutePrefixPrivate, "users/{userID}/permissions", []string{"GET", "OPTIONS"}, nil, user.UserSpacePermissions) + Add(rt, RoutePrefixPrivate, "users", []string{"POST", "OPTIONS"}, nil, user.Add) + Add(rt, RoutePrefixPrivate, "users/folder/{folderID}", []string{"GET", "OPTIONS"}, nil, user.GetSpaceUsers) + Add(rt, RoutePrefixPrivate, "users", []string{"GET", "OPTIONS"}, nil, user.GetOrganizationUsers) + Add(rt, RoutePrefixPrivate, "users/{userID}", []string{"GET", "OPTIONS"}, nil, user.Get) + Add(rt, RoutePrefixPrivate, "users/{userID}", []string{"PUT", "OPTIONS"}, nil, user.Update) + Add(rt, RoutePrefixPrivate, "users/{userID}", []string{"DELETE", "OPTIONS"}, nil, user.Delete) Add(rt, RoutePrefixPrivate, "users/sync", []string{"GET", "OPTIONS"}, nil, endpoint.SyncKeycloak) - // Search Add(rt, RoutePrefixPrivate, "search", []string{"GET", "OPTIONS"}, nil, endpoint.SearchDocuments) - // Templates Add(rt, RoutePrefixPrivate, "templates", []string{"POST", "OPTIONS"}, nil, endpoint.SaveAsTemplate) Add(rt, RoutePrefixPrivate, "templates", []string{"GET", "OPTIONS"}, nil, endpoint.GetSavedTemplates) Add(rt, RoutePrefixPrivate, "templates/stock", []string{"GET", "OPTIONS"}, nil, endpoint.GetStockTemplates) Add(rt, RoutePrefixPrivate, "templates/{templateID}/folder/{folderID}", []string{"POST", "OPTIONS"}, []string{"type", "stock"}, endpoint.StartDocumentFromStockTemplate) Add(rt, RoutePrefixPrivate, "templates/{templateID}/folder/{folderID}", []string{"POST", "OPTIONS"}, []string{"type", "saved"}, endpoint.StartDocumentFromSavedTemplate) - // Sections Add(rt, RoutePrefixPrivate, "sections", []string{"GET", "OPTIONS"}, nil, endpoint.GetSections) Add(rt, RoutePrefixPrivate, "sections", []string{"POST", "OPTIONS"}, nil, endpoint.RunSectionCommand) Add(rt, RoutePrefixPrivate, "sections/refresh", []string{"GET", "OPTIONS"}, nil, endpoint.RefreshSections) @@ -128,28 +135,23 @@ func RegisterEndpoints(rt *env.Runtime, s *domain.Store) { Add(rt, RoutePrefixPrivate, "sections/blocks", []string{"POST", "OPTIONS"}, nil, endpoint.AddBlock) Add(rt, RoutePrefixPrivate, "sections/targets", []string{"GET", "OPTIONS"}, nil, endpoint.GetPageMoveCopyTargets) - // Links - Add(rt, RoutePrefixPrivate, "links/{folderID}/{documentID}/{pageID}", []string{"GET", "OPTIONS"}, nil, endpoint.GetLinkCandidates) - Add(rt, RoutePrefixPrivate, "links", []string{"GET", "OPTIONS"}, nil, endpoint.SearchLinkCandidates) + Add(rt, RoutePrefixPrivate, "links/{folderID}/{documentID}/{pageID}", []string{"GET", "OPTIONS"}, nil, link.GetLinkCandidates) + Add(rt, RoutePrefixPrivate, "links", []string{"GET", "OPTIONS"}, nil, link.SearchLinkCandidates) Add(rt, RoutePrefixPrivate, "documents/{documentID}/links", []string{"GET", "OPTIONS"}, nil, endpoint.GetDocumentLinks) - // Global installation-wide config - Add(rt, RoutePrefixPrivate, "global/smtp", []string{"GET", "OPTIONS"}, nil, endpoint.GetSMTPConfig) - Add(rt, RoutePrefixPrivate, "global/smtp", []string{"PUT", "OPTIONS"}, nil, endpoint.SaveSMTPConfig) - Add(rt, RoutePrefixPrivate, "global/license", []string{"GET", "OPTIONS"}, nil, endpoint.GetLicense) - Add(rt, RoutePrefixPrivate, "global/license", []string{"PUT", "OPTIONS"}, nil, endpoint.SaveLicense) - Add(rt, RoutePrefixPrivate, "global/auth", []string{"GET", "OPTIONS"}, nil, endpoint.GetAuthConfig) - Add(rt, RoutePrefixPrivate, "global/auth", []string{"PUT", "OPTIONS"}, nil, endpoint.SaveAuthConfig) + Add(rt, RoutePrefixPrivate, "global/smtp", []string{"GET", "OPTIONS"}, nil, setting.SMTP) + Add(rt, RoutePrefixPrivate, "global/smtp", []string{"PUT", "OPTIONS"}, nil, setting.SetSMTP) + Add(rt, RoutePrefixPrivate, "global/license", []string{"GET", "OPTIONS"}, nil, setting.License) + Add(rt, RoutePrefixPrivate, "global/license", []string{"PUT", "OPTIONS"}, nil, setting.SetLicense) + Add(rt, RoutePrefixPrivate, "global/auth", []string{"GET", "OPTIONS"}, nil, setting.AuthConfig) + Add(rt, RoutePrefixPrivate, "global/auth", []string{"PUT", "OPTIONS"}, nil, setting.SetAuthConfig) - // Pinned items - pin := pin.Handler{Runtime: rt, Store: s} Add(rt, RoutePrefixPrivate, "pin/{userID}", []string{"POST", "OPTIONS"}, nil, pin.Add) Add(rt, RoutePrefixPrivate, "pin/{userID}", []string{"GET", "OPTIONS"}, nil, pin.GetUserPins) Add(rt, RoutePrefixPrivate, "pin/{userID}/sequence", []string{"POST", "OPTIONS"}, nil, pin.UpdatePinSequence) Add(rt, RoutePrefixPrivate, "pin/{userID}/{pinID}", []string{"DELETE", "OPTIONS"}, nil, pin.DeleteUserPin) - // Single page app handler - Add(rt, RoutePrefixRoot, "robots.txt", []string{"GET", "OPTIONS"}, nil, endpoint.GetRobots) - Add(rt, RoutePrefixRoot, "sitemap.xml", []string{"GET", "OPTIONS"}, nil, endpoint.GetSitemap) + Add(rt, RoutePrefixRoot, "robots.txt", []string{"GET", "OPTIONS"}, nil, meta.RobotsTxt) + Add(rt, RoutePrefixRoot, "sitemap.xml", []string{"GET", "OPTIONS"}, nil, meta.Sitemap) Add(rt, RoutePrefixRoot, "{rest:.*}", nil, nil, web.EmberHandler) } diff --git a/vendor/github.com/jmoiron/sqlx/README.md b/vendor/github.com/jmoiron/sqlx/README.md index 55155d00..5c1bb3cb 100644 --- a/vendor/github.com/jmoiron/sqlx/README.md +++ b/vendor/github.com/jmoiron/sqlx/README.md @@ -1,4 +1,4 @@ -#sqlx +# sqlx [![Build Status](https://drone.io/github.com/jmoiron/sqlx/status.png)](https://drone.io/github.com/jmoiron/sqlx/latest) [![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/jmoiron/sqlx) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/jmoiron/sqlx/master/LICENSE) @@ -26,9 +26,7 @@ This breaks backwards compatibility, but it's in a way that is trivially fixable (`s/JsonText/JSONText/g`). The `types` package is both experimental and not in active development currently. -More importantly, [golang bug #13905](https://github.com/golang/go/issues/13905) -makes `types.JSONText` and `types.GzippedText` _potentially unsafe_, **especially** -when used with common auto-scan sqlx idioms like `Select` and `Get`. +* Using Go 1.6 and below with `types.JSONText` and `types.GzippedText` can be _potentially unsafe_, **especially** when used with common auto-scan sqlx idioms like `Select` and `Get`. See [golang bug #13905](https://github.com/golang/go/issues/13905). ### Backwards Compatibility diff --git a/vendor/github.com/jmoiron/sqlx/bind.go b/vendor/github.com/jmoiron/sqlx/bind.go index 564635ca..10f7bdf8 100644 --- a/vendor/github.com/jmoiron/sqlx/bind.go +++ b/vendor/github.com/jmoiron/sqlx/bind.go @@ -27,7 +27,7 @@ func BindType(driverName string) int { return QUESTION case "sqlite3": return QUESTION - case "oci8": + case "oci8", "ora", "goracle": return NAMED } return UNKNOWN @@ -43,27 +43,28 @@ func Rebind(bindType int, query string) string { return query } - qb := []byte(query) // Add space enough for 10 params before we have to allocate - rqb := make([]byte, 0, len(qb)+10) - j := 1 - for _, b := range qb { - if b == '?' { - switch bindType { - case DOLLAR: - rqb = append(rqb, '$') - case NAMED: - rqb = append(rqb, ':', 'a', 'r', 'g') - } - for _, b := range strconv.Itoa(j) { - rqb = append(rqb, byte(b)) - } - j++ - } else { - rqb = append(rqb, b) + rqb := make([]byte, 0, len(query)+10) + + var i, j int + + for i = strings.Index(query, "?"); i != -1; i = strings.Index(query, "?") { + rqb = append(rqb, query[:i]...) + + switch bindType { + case DOLLAR: + rqb = append(rqb, '$') + case NAMED: + rqb = append(rqb, ':', 'a', 'r', 'g') } + + j++ + rqb = strconv.AppendInt(rqb, int64(j), 10) + + query = query[i+1:] } - return string(rqb) + + return string(append(rqb, query...)) } // Experimental implementation of Rebind which uses a bytes.Buffer. The code is @@ -135,9 +136,9 @@ func In(query string, args ...interface{}) (string, []interface{}, error) { } newArgs := make([]interface{}, 0, flatArgsCount) + buf := bytes.NewBuffer(make([]byte, 0, len(query)+len(", ?")*flatArgsCount)) var arg, offset int - var buf bytes.Buffer for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') { if arg >= len(meta) { @@ -163,13 +164,12 @@ func In(query string, args ...interface{}) (string, []interface{}, error) { // write everything up to and including our ? character buf.WriteString(query[:offset+i+1]) - newArgs = append(newArgs, argMeta.v.Index(0).Interface()) - for si := 1; si < argMeta.length; si++ { buf.WriteString(", ?") - newArgs = append(newArgs, argMeta.v.Index(si).Interface()) } + newArgs = appendReflectSlice(newArgs, argMeta.v, argMeta.length) + // slice the query and reset the offset. this avoids some bookkeeping for // the write after the loop query = query[offset+i+1:] @@ -184,3 +184,24 @@ func In(query string, args ...interface{}) (string, []interface{}, error) { return buf.String(), newArgs, nil } + +func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} { + switch val := v.Interface().(type) { + case []interface{}: + args = append(args, val...) + case []int: + for i := range val { + args = append(args, val[i]) + } + case []string: + for i := range val { + args = append(args, val[i]) + } + default: + for si := 0; si < vlen; si++ { + args = append(args, v.Index(si).Interface()) + } + } + + return args +} diff --git a/vendor/github.com/jmoiron/sqlx/named.go b/vendor/github.com/jmoiron/sqlx/named.go index 4df8095d..dd899d35 100644 --- a/vendor/github.com/jmoiron/sqlx/named.go +++ b/vendor/github.com/jmoiron/sqlx/named.go @@ -36,6 +36,7 @@ func (n *NamedStmt) Close() error { } // Exec executes a named statement using the struct passed. +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { @@ -45,6 +46,7 @@ func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) { } // Query executes a named statement using the struct argument, returning rows. +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { @@ -56,6 +58,7 @@ func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) { // QueryRow executes a named statement against the database. Because sqlx cannot // create a *sql.Row with an error condition pre-set for binding errors, sqlx // returns a *sqlx.Row instead. +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryRow(arg interface{}) *Row { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { @@ -65,6 +68,7 @@ func (n *NamedStmt) QueryRow(arg interface{}) *Row { } // MustExec execs a NamedStmt, panicing on error +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) MustExec(arg interface{}) sql.Result { res, err := n.Exec(arg) if err != nil { @@ -74,6 +78,7 @@ func (n *NamedStmt) MustExec(arg interface{}) sql.Result { } // Queryx using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) { r, err := n.Query(arg) if err != nil { @@ -84,11 +89,13 @@ func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) { // QueryRowx this NamedStmt. Because of limitations with QueryRow, this is // an alias for QueryRow. +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryRowx(arg interface{}) *Row { return n.QueryRow(arg) } // Select using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Select(dest interface{}, arg interface{}) error { rows, err := n.Queryx(arg) if err != nil { @@ -100,6 +107,7 @@ func (n *NamedStmt) Select(dest interface{}, arg interface{}) error { } // Get using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Get(dest interface{}, arg interface{}) error { r := n.QueryRowx(arg) return r.scanAny(dest, false) @@ -250,7 +258,7 @@ func compileNamedQuery(qs []byte, bindType int) (query string, names []string, e inName = true name = []byte{} // if we're in a name, and this is an allowed character, continue - } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_') && i != last { + } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last { // append the byte to the name if we are in a name and not on the last byte name = append(name, b) // if we're in a name and it's not an allowed character, the name is done diff --git a/vendor/github.com/jmoiron/sqlx/named_context.go b/vendor/github.com/jmoiron/sqlx/named_context.go new file mode 100644 index 00000000..9405007e --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/named_context.go @@ -0,0 +1,132 @@ +// +build go1.8 + +package sqlx + +import ( + "context" + "database/sql" +) + +// A union interface of contextPreparer and binder, required to be able to +// prepare named statements with context (as the bindtype must be determined). +type namedPreparerContext interface { + PreparerContext + binder +} + +func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) { + bindType := BindType(p.DriverName()) + q, args, err := compileNamedQuery([]byte(query), bindType) + if err != nil { + return nil, err + } + stmt, err := PreparexContext(ctx, p, q) + if err != nil { + return nil, err + } + return &NamedStmt{ + QueryString: q, + Params: args, + Stmt: stmt, + }, nil +} + +// ExecContext executes a named statement using the struct passed. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) ExecContext(ctx context.Context, arg interface{}) (sql.Result, error) { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return *new(sql.Result), err + } + return n.Stmt.ExecContext(ctx, args...) +} + +// QueryContext executes a named statement using the struct argument, returning rows. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryContext(ctx context.Context, arg interface{}) (*sql.Rows, error) { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return nil, err + } + return n.Stmt.QueryContext(ctx, args...) +} + +// QueryRowContext executes a named statement against the database. Because sqlx cannot +// create a *sql.Row with an error condition pre-set for binding errors, sqlx +// returns a *sqlx.Row instead. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryRowContext(ctx context.Context, arg interface{}) *Row { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return &Row{err: err} + } + return n.Stmt.QueryRowxContext(ctx, args...) +} + +// MustExecContext execs a NamedStmt, panicing on error +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) MustExecContext(ctx context.Context, arg interface{}) sql.Result { + res, err := n.ExecContext(ctx, arg) + if err != nil { + panic(err) + } + return res +} + +// QueryxContext using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryxContext(ctx context.Context, arg interface{}) (*Rows, error) { + r, err := n.QueryContext(ctx, arg) + if err != nil { + return nil, err + } + return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err +} + +// QueryRowxContext this NamedStmt. Because of limitations with QueryRow, this is +// an alias for QueryRow. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryRowxContext(ctx context.Context, arg interface{}) *Row { + return n.QueryRowContext(ctx, arg) +} + +// SelectContext using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) SelectContext(ctx context.Context, dest interface{}, arg interface{}) error { + rows, err := n.QueryxContext(ctx, arg) + if err != nil { + return err + } + // if something happens here, we want to make sure the rows are Closed + defer rows.Close() + return scanAll(rows, dest, false) +} + +// GetContext using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) GetContext(ctx context.Context, dest interface{}, arg interface{}) error { + r := n.QueryRowxContext(ctx, arg) + return r.scanAny(dest, false) +} + +// NamedQueryContext binds a named query and then runs Query on the result using the +// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with +// map[string]interface{} types. +func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) { + q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) + if err != nil { + return nil, err + } + return e.QueryxContext(ctx, q, args...) +} + +// NamedExecContext uses BindStruct to get a query executable by the driver and +// then runs Exec on the result. Returns an error from the binding +// or the query excution itself. +func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) { + q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) + if err != nil { + return nil, err + } + return e.ExecContext(ctx, q, args...) +} diff --git a/vendor/github.com/jmoiron/sqlx/named_context_test.go b/vendor/github.com/jmoiron/sqlx/named_context_test.go new file mode 100644 index 00000000..87e94ac2 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/named_context_test.go @@ -0,0 +1,136 @@ +// +build go1.8 + +package sqlx + +import ( + "context" + "database/sql" + "testing" +) + +func TestNamedContextQueries(t *testing.T) { + RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { + loadDefaultFixture(db, t) + test := Test{t} + var ns *NamedStmt + var err error + + ctx := context.Background() + + // Check that invalid preparations fail + ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first:name") + if err == nil { + t.Error("Expected an error with invalid prepared statement.") + } + + ns, err = db.PrepareNamedContext(ctx, "invalid sql") + if err == nil { + t.Error("Expected an error with invalid prepared statement.") + } + + // Check closing works as anticipated + ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first_name") + test.Error(err) + err = ns.Close() + test.Error(err) + + ns, err = db.PrepareNamedContext(ctx, ` + SELECT first_name, last_name, email + FROM person WHERE first_name=:first_name AND email=:email`) + test.Error(err) + + // test Queryx w/ uses Query + p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"} + + rows, err := ns.QueryxContext(ctx, p) + test.Error(err) + for rows.Next() { + var p2 Person + rows.StructScan(&p2) + if p.FirstName != p2.FirstName { + t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName) + } + if p.LastName != p2.LastName { + t.Errorf("got %s, expected %s", p.LastName, p2.LastName) + } + if p.Email != p2.Email { + t.Errorf("got %s, expected %s", p.Email, p2.Email) + } + } + + // test Select + people := make([]Person, 0, 5) + err = ns.SelectContext(ctx, &people, p) + test.Error(err) + + if len(people) != 1 { + t.Errorf("got %d results, expected %d", len(people), 1) + } + if p.FirstName != people[0].FirstName { + t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName) + } + if p.LastName != people[0].LastName { + t.Errorf("got %s, expected %s", p.LastName, people[0].LastName) + } + if p.Email != people[0].Email { + t.Errorf("got %s, expected %s", p.Email, people[0].Email) + } + + // test Exec + ns, err = db.PrepareNamedContext(ctx, ` + INSERT INTO person (first_name, last_name, email) + VALUES (:first_name, :last_name, :email)`) + test.Error(err) + + js := Person{ + FirstName: "Julien", + LastName: "Savea", + Email: "jsavea@ab.co.nz", + } + _, err = ns.ExecContext(ctx, js) + test.Error(err) + + // Make sure we can pull him out again + p2 := Person{} + db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email) + if p2.Email != js.Email { + t.Errorf("expected %s, got %s", js.Email, p2.Email) + } + + // test Txn NamedStmts + tx := db.MustBeginTx(ctx, nil) + txns := tx.NamedStmtContext(ctx, ns) + + // We're going to add Steven in this txn + sl := Person{ + FirstName: "Steven", + LastName: "Luatua", + Email: "sluatua@ab.co.nz", + } + + _, err = txns.ExecContext(ctx, sl) + test.Error(err) + // then rollback... + tx.Rollback() + // looking for Steven after a rollback should fail + err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) + if err != sql.ErrNoRows { + t.Errorf("expected no rows error, got %v", err) + } + + // now do the same, but commit + tx = db.MustBeginTx(ctx, nil) + txns = tx.NamedStmtContext(ctx, ns) + _, err = txns.ExecContext(ctx, sl) + test.Error(err) + tx.Commit() + + // looking for Steven after a Commit should succeed + err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) + test.Error(err) + if p2.Email != sl.Email { + t.Errorf("expected %s, got %s", sl.Email, p2.Email) + } + + }) +} diff --git a/vendor/github.com/jmoiron/sqlx/reflectx/README.md b/vendor/github.com/jmoiron/sqlx/reflectx/README.md index 76f1b5df..f01d3d1f 100644 --- a/vendor/github.com/jmoiron/sqlx/reflectx/README.md +++ b/vendor/github.com/jmoiron/sqlx/reflectx/README.md @@ -12,6 +12,6 @@ behavior of standard Go accessors. The first two are amply taken care of by `Reflect.Value.FieldByName`, and the third is addressed by `Reflect.Value.FieldByNameFunc`, but these don't quite understand struct -tags in the ways that are vital to most marshalers, and they are slow. +tags in the ways that are vital to most marshallers, and they are slow. This reflectx package extends reflect to achieve these goals. diff --git a/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go b/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go index 04d2080e..f2802b80 100644 --- a/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go +++ b/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go @@ -1,5 +1,5 @@ // Package reflectx implements extensions to the standard reflect lib suitable -// for implementing marshaling and unmarshaling packages. The main Mapper type +// for implementing marshalling and unmarshalling packages. The main Mapper type // allows for Go-compatible named attribute access, including accessing embedded // struct attributes and the ability to use functions and struct tags to // customize field names. @@ -7,14 +7,13 @@ package reflectx import ( - "fmt" "reflect" "runtime" "strings" "sync" ) -// A FieldInfo is a collection of metadata about a struct field. +// A FieldInfo is metadata for a struct field. type FieldInfo struct { Index []int Path string @@ -41,7 +40,8 @@ func (f StructMap) GetByPath(path string) *FieldInfo { } // GetByTraversal returns a *FieldInfo for a given integer path. It is -// analogous to reflect.FieldByIndex. +// analogous to reflect.FieldByIndex, but using the cached traversal +// rather than re-executing the reflect machinery each time. func (f StructMap) GetByTraversal(index []int) *FieldInfo { if len(index) == 0 { return nil @@ -58,8 +58,8 @@ func (f StructMap) GetByTraversal(index []int) *FieldInfo { } // Mapper is a general purpose mapper of names to struct fields. A Mapper -// behaves like most marshallers, optionally obeying a field tag for name -// mapping and a function to provide a basic mapping of fields to names. +// behaves like most marshallers in the standard library, obeying a field tag +// for name mapping but also providing a basic transform function. type Mapper struct { cache map[reflect.Type]*StructMap tagName string @@ -68,8 +68,8 @@ type Mapper struct { mutex sync.Mutex } -// NewMapper returns a new mapper which optionally obeys the field tag given -// by tagName. If tagName is the empty string, it is ignored. +// NewMapper returns a new mapper using the tagName as its struct field tag. +// If tagName is the empty string, it is ignored. func NewMapper(tagName string) *Mapper { return &Mapper{ cache: make(map[reflect.Type]*StructMap), @@ -127,7 +127,7 @@ func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value { return r } -// FieldByName returns a field by the its mapped name as a reflect.Value. +// FieldByName returns a field by its mapped name as a reflect.Value. // Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind. // Returns zero Value if the name is not found. func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value { @@ -182,11 +182,12 @@ func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { return r } -// FieldByIndexes returns a value for a particular struct traversal. +// FieldByIndexes returns a value for the field given by the struct traversal +// for the given value. func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { for _, i := range indexes { v = reflect.Indirect(v).Field(i) - // if this is a pointer, it's possible it is nil + // if this is a pointer and it's nil, allocate a new value and set it if v.Kind() == reflect.Ptr && v.IsNil() { alloc := reflect.New(Deref(v.Type())) v.Set(alloc) @@ -225,13 +226,12 @@ type kinder interface { // mustBe checks a value against a kind, panicing with a reflect.ValueError // if the kind isn't that which is required. func mustBe(v kinder, expected reflect.Kind) { - k := v.Kind() - if k != expected { + if k := v.Kind(); k != expected { panic(&reflect.ValueError{Method: methodName(), Kind: k}) } } -// methodName is returns the caller of the function calling methodName +// methodName returns the caller of the function calling methodName func methodName() string { pc, _, _, _ := runtime.Caller(2) f := runtime.FuncForPC(pc) @@ -257,19 +257,92 @@ func apnd(is []int, i int) []int { return x } +type mapf func(string) string + +// parseName parses the tag and the target name for the given field using +// the tagName (eg 'json' for `json:"foo"` tags), mapFunc for mapping the +// field's name to a target name, and tagMapFunc for mapping the tag to +// a target name. +func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) { + // first, set the fieldName to the field's name + fieldName = field.Name + // if a mapFunc is set, use that to override the fieldName + if mapFunc != nil { + fieldName = mapFunc(fieldName) + } + + // if there's no tag to look for, return the field name + if tagName == "" { + return "", fieldName + } + + // if this tag is not set using the normal convention in the tag, + // then return the fieldname.. this check is done because according + // to the reflect documentation: + // If the tag does not have the conventional format, + // the value returned by Get is unspecified. + // which doesn't sound great. + if !strings.Contains(string(field.Tag), tagName+":") { + return "", fieldName + } + + // at this point we're fairly sure that we have a tag, so lets pull it out + tag = field.Tag.Get(tagName) + + // if we have a mapper function, call it on the whole tag + // XXX: this is a change from the old version, which pulled out the name + // before the tagMapFunc could be run, but I think this is the right way + if tagMapFunc != nil { + tag = tagMapFunc(tag) + } + + // finally, split the options from the name + parts := strings.Split(tag, ",") + fieldName = parts[0] + + return tag, fieldName +} + +// parseOptions parses options out of a tag string, skipping the name +func parseOptions(tag string) map[string]string { + parts := strings.Split(tag, ",") + options := make(map[string]string, len(parts)) + if len(parts) > 1 { + for _, opt := range parts[1:] { + // short circuit potentially expensive split op + if strings.Contains(opt, "=") { + kv := strings.Split(opt, "=") + options[kv[0]] = kv[1] + continue + } + options[opt] = "" + } + } + return options +} + // getMapping returns a mapping for the t type, using the tagName, mapFunc and // tagMapFunc to determine the canonical names of fields. -func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string) string) *StructMap { +func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap { m := []*FieldInfo{} root := &FieldInfo{} queue := []typeQueue{} queue = append(queue, typeQueue{Deref(t), root, ""}) +QueueLoop: for len(queue) != 0 { // pop the first item off of the queue tq := queue[0] queue = queue[1:] + + // ignore recursive field + for p := tq.fi.Parent; p != nil; p = p.Parent { + if tq.fi.Field.Type == p.Field.Type { + continue QueueLoop + } + } + nChildren := 0 if tq.t.Kind() == reflect.Struct { nChildren = tq.t.NumField() @@ -278,53 +351,31 @@ func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string) // iterate through all of its fields for fieldPos := 0; fieldPos < nChildren; fieldPos++ { + f := tq.t.Field(fieldPos) - fi := FieldInfo{} - fi.Field = f - fi.Zero = reflect.New(f.Type).Elem() - fi.Options = map[string]string{} - - var tag, name string - if tagName != "" && strings.Contains(string(f.Tag), tagName+":") { - tag = f.Tag.Get(tagName) - name = tag - } else { - if mapFunc != nil { - name = mapFunc(f.Name) - } - } - - parts := strings.Split(name, ",") - if len(parts) > 1 { - name = parts[0] - for _, opt := range parts[1:] { - kv := strings.Split(opt, "=") - if len(kv) > 1 { - fi.Options[kv[0]] = kv[1] - } else { - fi.Options[kv[0]] = "" - } - } - } - - if tagMapFunc != nil { - tag = tagMapFunc(tag) - } - - fi.Name = name - - if tq.pp == "" || (tq.pp == "" && tag == "") { - fi.Path = fi.Name - } else { - fi.Path = fmt.Sprintf("%s.%s", tq.pp, fi.Name) - } + // parse the tag and the target name using the mapping options for this field + tag, name := parseName(f, tagName, mapFunc, tagMapFunc) // if the name is "-", disabled via a tag, skip it if name == "-" { continue } + fi := FieldInfo{ + Field: f, + Name: name, + Zero: reflect.New(f.Type).Elem(), + Options: parseOptions(tag), + } + + // if the path is empty this path is just the name + if tq.pp == "" { + fi.Path = fi.Name + } else { + fi.Path = tq.pp + "." + fi.Name + } + // skip unexported fields if len(f.PkgPath) != 0 && !f.Anonymous { continue diff --git a/vendor/github.com/jmoiron/sqlx/reflectx/reflect_test.go b/vendor/github.com/jmoiron/sqlx/reflectx/reflect_test.go index 80722443..b702f9cd 100644 --- a/vendor/github.com/jmoiron/sqlx/reflectx/reflect_test.go +++ b/vendor/github.com/jmoiron/sqlx/reflectx/reflect_test.go @@ -247,11 +247,20 @@ func TestInlineStruct(t *testing.T) { } } +func TestRecursiveStruct(t *testing.T) { + type Person struct { + Parent *Person + } + m := NewMapperFunc("db", strings.ToLower) + var p *Person + m.TypeMap(reflect.TypeOf(p)) +} + func TestFieldsEmbedded(t *testing.T) { m := NewMapper("db") type Person struct { - Name string `db:"name"` + Name string `db:"name,size=64"` } type Place struct { Name string `db:"name"` @@ -311,6 +320,9 @@ func TestFieldsEmbedded(t *testing.T) { if fi.Path != "person.name" { t.Errorf("Expecting %s, got %s", "person.name", fi.Path) } + if fi.Options["size"] != "64" { + t.Errorf("Expecting %s, got %s", "64", fi.Options["size"]) + } fi = fields.GetByTraversal([]int{1, 0}) if fi == nil { @@ -508,6 +520,312 @@ func TestMapping(t *testing.T) { } } +func TestGetByTraversal(t *testing.T) { + type C struct { + C0 int + C1 int + } + type B struct { + B0 string + B1 *C + } + type A struct { + A0 int + A1 B + } + + testCases := []struct { + Index []int + ExpectedName string + ExpectNil bool + }{ + { + Index: []int{0}, + ExpectedName: "A0", + }, + { + Index: []int{1, 0}, + ExpectedName: "B0", + }, + { + Index: []int{1, 1, 1}, + ExpectedName: "C1", + }, + { + Index: []int{3, 4, 5}, + ExpectNil: true, + }, + { + Index: []int{}, + ExpectNil: true, + }, + { + Index: nil, + ExpectNil: true, + }, + } + + m := NewMapperFunc("db", func(n string) string { return n }) + tm := m.TypeMap(reflect.TypeOf(A{})) + + for i, tc := range testCases { + fi := tm.GetByTraversal(tc.Index) + if tc.ExpectNil { + if fi != nil { + t.Errorf("%d: expected nil, got %v", i, fi) + } + continue + } + + if fi == nil { + t.Errorf("%d: expected %s, got nil", i, tc.ExpectedName) + continue + } + + if fi.Name != tc.ExpectedName { + t.Errorf("%d: expected %s, got %s", i, tc.ExpectedName, fi.Name) + } + } +} + +// TestMapperMethodsByName tests Mapper methods FieldByName and TraversalsByName +func TestMapperMethodsByName(t *testing.T) { + type C struct { + C0 string + C1 int + } + type B struct { + B0 *C `db:"B0"` + B1 C `db:"B1"` + B2 string `db:"B2"` + } + type A struct { + A0 *B `db:"A0"` + B `db:"A1"` + A2 int + a3 int + } + + val := &A{ + A0: &B{ + B0: &C{C0: "0", C1: 1}, + B1: C{C0: "2", C1: 3}, + B2: "4", + }, + B: B{ + B0: nil, + B1: C{C0: "5", C1: 6}, + B2: "7", + }, + A2: 8, + } + + testCases := []struct { + Name string + ExpectInvalid bool + ExpectedValue interface{} + ExpectedIndexes []int + }{ + { + Name: "A0.B0.C0", + ExpectedValue: "0", + ExpectedIndexes: []int{0, 0, 0}, + }, + { + Name: "A0.B0.C1", + ExpectedValue: 1, + ExpectedIndexes: []int{0, 0, 1}, + }, + { + Name: "A0.B1.C0", + ExpectedValue: "2", + ExpectedIndexes: []int{0, 1, 0}, + }, + { + Name: "A0.B1.C1", + ExpectedValue: 3, + ExpectedIndexes: []int{0, 1, 1}, + }, + { + Name: "A0.B2", + ExpectedValue: "4", + ExpectedIndexes: []int{0, 2}, + }, + { + Name: "A1.B0.C0", + ExpectedValue: "", + ExpectedIndexes: []int{1, 0, 0}, + }, + { + Name: "A1.B0.C1", + ExpectedValue: 0, + ExpectedIndexes: []int{1, 0, 1}, + }, + { + Name: "A1.B1.C0", + ExpectedValue: "5", + ExpectedIndexes: []int{1, 1, 0}, + }, + { + Name: "A1.B1.C1", + ExpectedValue: 6, + ExpectedIndexes: []int{1, 1, 1}, + }, + { + Name: "A1.B2", + ExpectedValue: "7", + ExpectedIndexes: []int{1, 2}, + }, + { + Name: "A2", + ExpectedValue: 8, + ExpectedIndexes: []int{2}, + }, + { + Name: "XYZ", + ExpectInvalid: true, + ExpectedIndexes: []int{}, + }, + { + Name: "a3", + ExpectInvalid: true, + ExpectedIndexes: []int{}, + }, + } + + // build the names array from the test cases + names := make([]string, len(testCases)) + for i, tc := range testCases { + names[i] = tc.Name + } + m := NewMapperFunc("db", func(n string) string { return n }) + v := reflect.ValueOf(val) + values := m.FieldsByName(v, names) + if len(values) != len(testCases) { + t.Errorf("expected %d values, got %d", len(testCases), len(values)) + t.FailNow() + } + indexes := m.TraversalsByName(v.Type(), names) + if len(indexes) != len(testCases) { + t.Errorf("expected %d traversals, got %d", len(testCases), len(indexes)) + t.FailNow() + } + for i, val := range values { + tc := testCases[i] + traversal := indexes[i] + if !reflect.DeepEqual(tc.ExpectedIndexes, traversal) { + t.Errorf("expected %v, got %v", tc.ExpectedIndexes, traversal) + t.FailNow() + } + val = reflect.Indirect(val) + if tc.ExpectInvalid { + if val.IsValid() { + t.Errorf("%d: expected zero value, got %v", i, val) + } + continue + } + if !val.IsValid() { + t.Errorf("%d: expected valid value, got %v", i, val) + continue + } + actualValue := reflect.Indirect(val).Interface() + if !reflect.DeepEqual(tc.ExpectedValue, actualValue) { + t.Errorf("%d: expected %v, got %v", i, tc.ExpectedValue, actualValue) + } + } +} + +func TestFieldByIndexes(t *testing.T) { + type C struct { + C0 bool + C1 string + C2 int + C3 map[string]int + } + type B struct { + B1 C + B2 *C + } + type A struct { + A1 B + A2 *B + } + testCases := []struct { + value interface{} + indexes []int + expectedValue interface{} + readOnly bool + }{ + { + value: A{ + A1: B{B1: C{C0: true}}, + }, + indexes: []int{0, 0, 0}, + expectedValue: true, + readOnly: true, + }, + { + value: A{ + A2: &B{B2: &C{C1: "answer"}}, + }, + indexes: []int{1, 1, 1}, + expectedValue: "answer", + readOnly: true, + }, + { + value: &A{}, + indexes: []int{1, 1, 3}, + expectedValue: map[string]int{}, + }, + } + + for i, tc := range testCases { + checkResults := func(v reflect.Value) { + if tc.expectedValue == nil { + if !v.IsNil() { + t.Errorf("%d: expected nil, actual %v", i, v.Interface()) + } + } else { + if !reflect.DeepEqual(tc.expectedValue, v.Interface()) { + t.Errorf("%d: expected %v, actual %v", i, tc.expectedValue, v.Interface()) + } + } + } + + checkResults(FieldByIndexes(reflect.ValueOf(tc.value), tc.indexes)) + if tc.readOnly { + checkResults(FieldByIndexesReadOnly(reflect.ValueOf(tc.value), tc.indexes)) + } + } +} + +func TestMustBe(t *testing.T) { + typ := reflect.TypeOf(E1{}) + mustBe(typ, reflect.Struct) + + defer func() { + if r := recover(); r != nil { + valueErr, ok := r.(*reflect.ValueError) + if !ok { + t.Errorf("unexpected Method: %s", valueErr.Method) + t.Error("expected panic with *reflect.ValueError") + return + } + if valueErr.Method != "github.com/jmoiron/sqlx/reflectx.TestMustBe" { + } + if valueErr.Kind != reflect.String { + t.Errorf("unexpected Kind: %s", valueErr.Kind) + } + } else { + t.Error("expected panic") + } + }() + + typ = reflect.TypeOf("string") + mustBe(typ, reflect.Struct) + t.Error("got here, didn't expect to") +} + type E1 struct { A int } diff --git a/vendor/github.com/jmoiron/sqlx/sqlx.go b/vendor/github.com/jmoiron/sqlx/sqlx.go index 74e0a31b..4859d5ac 100644 --- a/vendor/github.com/jmoiron/sqlx/sqlx.go +++ b/vendor/github.com/jmoiron/sqlx/sqlx.go @@ -10,6 +10,7 @@ import ( "path/filepath" "reflect" "strings" + "sync" "github.com/jmoiron/sqlx/reflectx" ) @@ -17,7 +18,7 @@ import ( // Although the NameMapper is convenient, in practice it should not // be relied on except for application code. If you are writing a library // that uses sqlx, you should be aware that the name mappings you expect -// can be overridded by your user's application. +// can be overridden by your user's application. // NameMapper is used to map column names to struct field names. By default, // it uses strings.ToLower to lowercase struct field names. It can be set @@ -30,8 +31,14 @@ var origMapper = reflect.ValueOf(NameMapper) // importers have time to customize the NameMapper. var mpr *reflectx.Mapper +// mprMu protects mpr. +var mprMu sync.Mutex + // mapper returns a valid mapper using the configured NameMapper func. func mapper() *reflectx.Mapper { + mprMu.Lock() + defer mprMu.Unlock() + if mpr == nil { mpr = reflectx.NewMapperFunc("db", NameMapper) } else if origMapper != reflect.ValueOf(NameMapper) { @@ -289,21 +296,26 @@ func (db *DB) BindNamed(query string, arg interface{}) (string, []interface{}, e } // NamedQuery using this DB. +// Any named placeholder parameters are replaced with fields from arg. func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) { return NamedQuery(db, query, arg) } // NamedExec using this DB. +// Any named placeholder parameters are replaced with fields from arg. func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) { return NamedExec(db, query, arg) } // Select using this DB. +// Any placeholder parameters are replaced with supplied args. func (db *DB) Select(dest interface{}, query string, args ...interface{}) error { return Select(db, dest, query, args...) } // Get using this DB. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. func (db *DB) Get(dest interface{}, query string, args ...interface{}) error { return Get(db, dest, query, args...) } @@ -328,6 +340,7 @@ func (db *DB) Beginx() (*Tx, error) { } // Queryx queries the database and returns an *sqlx.Rows. +// Any placeholder parameters are replaced with supplied args. func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) { r, err := db.DB.Query(query, args...) if err != nil { @@ -337,12 +350,14 @@ func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) { } // QueryRowx queries the database and returns an *sqlx.Row. +// Any placeholder parameters are replaced with supplied args. func (db *DB) QueryRowx(query string, args ...interface{}) *Row { rows, err := db.DB.Query(query, args...) return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} } // MustExec (panic) runs MustExec using this database. +// Any placeholder parameters are replaced with supplied args. func (db *DB) MustExec(query string, args ...interface{}) sql.Result { return MustExec(db, query, args...) } @@ -387,21 +402,25 @@ func (tx *Tx) BindNamed(query string, arg interface{}) (string, []interface{}, e } // NamedQuery within a transaction. +// Any named placeholder parameters are replaced with fields from arg. func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) { return NamedQuery(tx, query, arg) } // NamedExec a named query within a transaction. +// Any named placeholder parameters are replaced with fields from arg. func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) { return NamedExec(tx, query, arg) } // Select within a transaction. +// Any placeholder parameters are replaced with supplied args. func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error { return Select(tx, dest, query, args...) } // Queryx within a transaction. +// Any placeholder parameters are replaced with supplied args. func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) { r, err := tx.Tx.Query(query, args...) if err != nil { @@ -411,17 +430,21 @@ func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) { } // QueryRowx within a transaction. +// Any placeholder parameters are replaced with supplied args. func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row { rows, err := tx.Tx.Query(query, args...) return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} } // Get within a transaction. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error { return Get(tx, dest, query, args...) } // MustExec runs MustExec within a transaction. +// Any placeholder parameters are replaced with supplied args. func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result { return MustExec(tx, query, args...) } @@ -478,28 +501,34 @@ func (s *Stmt) Unsafe() *Stmt { } // Select using the prepared statement. +// Any placeholder parameters are replaced with supplied args. func (s *Stmt) Select(dest interface{}, args ...interface{}) error { return Select(&qStmt{s}, dest, "", args...) } // Get using the prepared statement. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. func (s *Stmt) Get(dest interface{}, args ...interface{}) error { return Get(&qStmt{s}, dest, "", args...) } // MustExec (panic) using this statement. Note that the query portion of the error // output will be blank, as Stmt does not expose its query. +// Any placeholder parameters are replaced with supplied args. func (s *Stmt) MustExec(args ...interface{}) sql.Result { return MustExec(&qStmt{s}, "", args...) } // QueryRowx using this statement. +// Any placeholder parameters are replaced with supplied args. func (s *Stmt) QueryRowx(args ...interface{}) *Row { qs := &qStmt{s} return qs.QueryRowx("", args...) } // Queryx using this statement. +// Any placeholder parameters are replaced with supplied args. func (s *Stmt) Queryx(args ...interface{}) (*Rows, error) { qs := &qStmt{s} return qs.Queryx("", args...) @@ -576,7 +605,7 @@ func (r *Rows) StructScan(dest interface{}) error { r.fields = m.TraversalsByName(v.Type(), columns) // if we are not unsafe and are missing fields, return an error if f, err := missingFields(r.fields); err != nil && !r.unsafe { - return fmt.Errorf("missing destination name %s", columns[f]) + return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } r.values = make([]interface{}, len(columns)) r.started = true @@ -626,6 +655,7 @@ func Preparex(p Preparer, query string) (*Stmt, error) { // into dest, which must be a slice. If the slice elements are scannable, then // the result set must have only one column. Otherwise, StructScan is used. // The *sql.Rows are closed automatically. +// Any placeholder parameters are replaced with supplied args. func Select(q Queryer, dest interface{}, query string, args ...interface{}) error { rows, err := q.Queryx(query, args...) if err != nil { @@ -639,6 +669,8 @@ func Select(q Queryer, dest interface{}, query string, args ...interface{}) erro // Get does a QueryRow using the provided Queryer, and scans the resulting row // to dest. If dest is scannable, the result must only have one column. Otherwise, // StructScan is used. Get will return sql.ErrNoRows like row.Scan would. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. func Get(q Queryer, dest interface{}, query string, args ...interface{}) error { r := q.QueryRowx(query, args...) return r.scanAny(dest, false) @@ -669,6 +701,7 @@ func LoadFile(e Execer, path string) (*sql.Result, error) { } // MustExec execs the query using e and panics if there was an error. +// Any placeholder parameters are replaced with supplied args. func MustExec(e Execer, query string, args ...interface{}) sql.Result { res, err := e.Exec(query, args...) if err != nil { @@ -691,6 +724,10 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { if r.err != nil { return r.err } + if r.rows == nil { + r.err = sql.ErrNoRows + return r.err + } defer r.rows.Close() v := reflect.ValueOf(dest) @@ -726,7 +763,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { fields := m.TraversalsByName(v.Type(), columns) // if we are not unsafe and are missing fields, return an error if f, err := missingFields(fields); err != nil && !r.unsafe { - return fmt.Errorf("missing destination name %s", columns[f]) + return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values := make([]interface{}, len(columns)) @@ -779,7 +816,7 @@ func SliceScan(r ColScanner) ([]interface{}, error) { // executes SQL from input). Please do not use this as a primary interface! // This will modify the map sent to it in place, so reuse the same map with // care. Columns which occur more than once in the result will overwrite -// eachother! +// each other! func MapScan(r ColScanner, dest map[string]interface{}) error { // ignore r.started, since we needn't use reflect for anything. columns, err := r.Columns() @@ -892,7 +929,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { fields := m.TraversalsByName(base, columns) // if we are not unsafe and are missing fields, return an error if f, err := missingFields(fields); err != nil && !isUnsafe(rows) { - return fmt.Errorf("missing destination name %s", columns[f]) + return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values = make([]interface{}, len(columns)) @@ -902,6 +939,9 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { v = reflect.Indirect(vp) err = fieldsByTraversal(v, fields, values, true) + if err != nil { + return err + } // scan into the struct field pointers and append to our results err = rows.Scan(values...) @@ -919,6 +959,9 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { for rows.Next() { vp = reflect.New(base) err = rows.Scan(vp.Interface()) + if err != nil { + return err + } // append if isPtr { direct.Set(reflect.Append(direct, vp)) @@ -937,7 +980,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { // anyway) works on a rows object. // StructScan all rows from an sql.Rows or an sqlx.Rows into the dest slice. -// StructScan will scan in the entire rows result, so if you need do not want to +// StructScan will scan in the entire rows result, so if you do not want to // allocate structs for the entire result, use Queryx and see sqlx.Rows.StructScan. // If rows is sqlx.Rows, it will use its mapper, otherwise it will use the default. func StructScan(rows rowsi, dest interface{}) error { diff --git a/vendor/github.com/jmoiron/sqlx/sqlx_context.go b/vendor/github.com/jmoiron/sqlx/sqlx_context.go new file mode 100644 index 00000000..0b171451 --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/sqlx_context.go @@ -0,0 +1,335 @@ +// +build go1.8 + +package sqlx + +import ( + "context" + "database/sql" + "fmt" + "io/ioutil" + "path/filepath" + "reflect" +) + +// ConnectContext to a database and verify with a ping. +func ConnectContext(ctx context.Context, driverName, dataSourceName string) (*DB, error) { + db, err := Open(driverName, dataSourceName) + if err != nil { + return db, err + } + err = db.PingContext(ctx) + return db, err +} + +// QueryerContext is an interface used by GetContext and SelectContext +type QueryerContext interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) + QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row +} + +// PreparerContext is an interface used by PreparexContext. +type PreparerContext interface { + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) +} + +// ExecerContext is an interface used by MustExecContext and LoadFileContext +type ExecerContext interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +// ExtContext is a union interface which can bind, query, and exec, with Context +// used by NamedQueryContext and NamedExecContext. +type ExtContext interface { + binder + QueryerContext + ExecerContext +} + +// SelectContext executes a query using the provided Queryer, and StructScans +// each row into dest, which must be a slice. If the slice elements are +// scannable, then the result set must have only one column. Otherwise, +// StructScan is used. The *sql.Rows are closed automatically. +// Any placeholder parameters are replaced with supplied args. +func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { + rows, err := q.QueryxContext(ctx, query, args...) + if err != nil { + return err + } + // if something happens here, we want to make sure the rows are Closed + defer rows.Close() + return scanAll(rows, dest, false) +} + +// PreparexContext prepares a statement. +// +// The provided context is used for the preparation of the statement, not for +// the execution of the statement. +func PreparexContext(ctx context.Context, p PreparerContext, query string) (*Stmt, error) { + s, err := p.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err +} + +// GetContext does a QueryRow using the provided Queryer, and scans the +// resulting row to dest. If dest is scannable, the result must only have one +// column. Otherwise, StructScan is used. Get will return sql.ErrNoRows like +// row.Scan would. Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { + r := q.QueryRowxContext(ctx, query, args...) + return r.scanAny(dest, false) +} + +// LoadFileContext exec's every statement in a file (as a single call to Exec). +// LoadFileContext may return a nil *sql.Result if errors are encountered +// locating or reading the file at path. LoadFile reads the entire file into +// memory, so it is not suitable for loading large data dumps, but can be useful +// for initializing schemas or loading indexes. +// +// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 +// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting +// this by requiring something with DriverName() and then attempting to split the +// queries will be difficult to get right, and its current driver-specific behavior +// is deemed at least not complex in its incorrectness. +func LoadFileContext(ctx context.Context, e ExecerContext, path string) (*sql.Result, error) { + realpath, err := filepath.Abs(path) + if err != nil { + return nil, err + } + contents, err := ioutil.ReadFile(realpath) + if err != nil { + return nil, err + } + res, err := e.ExecContext(ctx, string(contents)) + return &res, err +} + +// MustExecContext execs the query using e and panics if there was an error. +// Any placeholder parameters are replaced with supplied args. +func MustExecContext(ctx context.Context, e ExecerContext, query string, args ...interface{}) sql.Result { + res, err := e.ExecContext(ctx, query, args...) + if err != nil { + panic(err) + } + return res +} + +// PrepareNamedContext returns an sqlx.NamedStmt +func (db *DB) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { + return prepareNamedContext(ctx, db, query) +} + +// NamedQueryContext using this DB. +// Any named placeholder parameters are replaced with fields from arg. +func (db *DB) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*Rows, error) { + return NamedQueryContext(ctx, db, query, arg) +} + +// NamedExecContext using this DB. +// Any named placeholder parameters are replaced with fields from arg. +func (db *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { + return NamedExecContext(ctx, db, query, arg) +} + +// SelectContext using this DB. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return SelectContext(ctx, db, dest, query, args...) +} + +// GetContext using this DB. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return GetContext(ctx, db, dest, query, args...) +} + +// PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. +// +// The provided context is used for the preparation of the statement, not for +// the execution of the statement. +func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) { + return PreparexContext(ctx, db, query) +} + +// QueryxContext queries the database and returns an *sqlx.Rows. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + r, err := db.DB.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err +} + +// QueryRowxContext queries the database and returns an *sqlx.Row. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := db.DB.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} +} + +// MustBeginTx starts a transaction, and panics on error. Returns an *sqlx.Tx instead +// of an *sql.Tx. +// +// The provided context is used until the transaction is committed or rolled +// back. If the context is canceled, the sql package will roll back the +// transaction. Tx.Commit will return an error if the context provided to +// MustBeginContext is canceled. +func (db *DB) MustBeginTx(ctx context.Context, opts *sql.TxOptions) *Tx { + tx, err := db.BeginTxx(ctx, opts) + if err != nil { + panic(err) + } + return tx +} + +// MustExecContext (panic) runs MustExec using this database. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { + return MustExecContext(ctx, db, query, args...) +} + +// BeginTxx begins a transaction and returns an *sqlx.Tx instead of an +// *sql.Tx. +// +// The provided context is used until the transaction is committed or rolled +// back. If the context is canceled, the sql package will roll back the +// transaction. Tx.Commit will return an error if the context provided to +// BeginxContext is canceled. +func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + tx, err := db.DB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err +} + +// StmtxContext returns a version of the prepared statement which runs within a +// transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt. +func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *Stmt { + var s *sql.Stmt + switch v := stmt.(type) { + case Stmt: + s = v.Stmt + case *Stmt: + s = v.Stmt + case sql.Stmt: + s = &v + case *sql.Stmt: + s = v + default: + panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type())) + } + return &Stmt{Stmt: tx.StmtContext(ctx, s), Mapper: tx.Mapper} +} + +// NamedStmtContext returns a version of the prepared statement which runs +// within a transaction. +func (tx *Tx) NamedStmtContext(ctx context.Context, stmt *NamedStmt) *NamedStmt { + return &NamedStmt{ + QueryString: stmt.QueryString, + Params: stmt.Params, + Stmt: tx.StmtxContext(ctx, stmt.Stmt), + } +} + +// MustExecContext runs MustExecContext within a transaction. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { + return MustExecContext(ctx, tx, query, args...) +} + +// QueryxContext within a transaction and context. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + r, err := tx.Tx.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err +} + +// SelectContext within a transaction and context. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return SelectContext(ctx, tx, dest, query, args...) +} + +// GetContext within a transaction and context. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (tx *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return GetContext(ctx, tx, dest, query, args...) +} + +// QueryRowxContext within a transaction and context. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := tx.Tx.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} +} + +// NamedExecContext using this Tx. +// Any named placeholder parameters are replaced with fields from arg. +func (tx *Tx) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { + return NamedExecContext(ctx, tx, query, arg) +} + +// SelectContext using the prepared statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) SelectContext(ctx context.Context, dest interface{}, args ...interface{}) error { + return SelectContext(ctx, &qStmt{s}, dest, "", args...) +} + +// GetContext using the prepared statement. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (s *Stmt) GetContext(ctx context.Context, dest interface{}, args ...interface{}) error { + return GetContext(ctx, &qStmt{s}, dest, "", args...) +} + +// MustExecContext (panic) using this statement. Note that the query portion of +// the error output will be blank, as Stmt does not expose its query. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) MustExecContext(ctx context.Context, args ...interface{}) sql.Result { + return MustExecContext(ctx, &qStmt{s}, "", args...) +} + +// QueryRowxContext using this statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) QueryRowxContext(ctx context.Context, args ...interface{}) *Row { + qs := &qStmt{s} + return qs.QueryRowxContext(ctx, "", args...) +} + +// QueryxContext using this statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) QueryxContext(ctx context.Context, args ...interface{}) (*Rows, error) { + qs := &qStmt{s} + return qs.QueryxContext(ctx, "", args...) +} + +func (q *qStmt) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return q.Stmt.QueryContext(ctx, args...) +} + +func (q *qStmt) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + r, err := q.Stmt.QueryContext(ctx, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err +} + +func (q *qStmt) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := q.Stmt.QueryContext(ctx, args...) + return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} +} + +func (q *qStmt) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return q.Stmt.ExecContext(ctx, args...) +} diff --git a/vendor/github.com/jmoiron/sqlx/sqlx_context_test.go b/vendor/github.com/jmoiron/sqlx/sqlx_context_test.go new file mode 100644 index 00000000..85e112bd --- /dev/null +++ b/vendor/github.com/jmoiron/sqlx/sqlx_context_test.go @@ -0,0 +1,1344 @@ +// +build go1.8 + +// The following environment variables, if set, will be used: +// +// * SQLX_SQLITE_DSN +// * SQLX_POSTGRES_DSN +// * SQLX_MYSQL_DSN +// +// Set any of these variables to 'skip' to skip them. Note that for MySQL, +// the string '?parseTime=True' will be appended to the DSN if it's not there +// already. +// +package sqlx + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "log" + "strings" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + "github.com/jmoiron/sqlx/reflectx" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" +) + +func MultiExecContext(ctx context.Context, e ExecerContext, query string) { + stmts := strings.Split(query, ";\n") + if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 { + stmts = stmts[:len(stmts)-1] + } + for _, s := range stmts { + _, err := e.ExecContext(ctx, s) + if err != nil { + fmt.Println(err, s) + } + } +} + +func RunWithSchemaContext(ctx context.Context, schema Schema, t *testing.T, test func(ctx context.Context, db *DB, t *testing.T)) { + runner := func(ctx context.Context, db *DB, t *testing.T, create, drop string) { + defer func() { + MultiExecContext(ctx, db, drop) + }() + + MultiExecContext(ctx, db, create) + test(ctx, db, t) + } + + if TestPostgres { + create, drop := schema.Postgres() + runner(ctx, pgdb, t, create, drop) + } + if TestSqlite { + create, drop := schema.Sqlite3() + runner(ctx, sldb, t, create, drop) + } + if TestMysql { + create, drop := schema.MySQL() + runner(ctx, mysqldb, t, create, drop) + } +} + +func loadDefaultFixtureContext(ctx context.Context, db *DB, t *testing.T) { + tx := db.MustBeginTx(ctx, nil) + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") + if db.DriverName() == "mysql" { + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (`COUNTRY`, `TELCODE`) VALUES (?, ?)"), "Sarf Efrica", "27") + } else { + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (\"COUNTRY\", \"TELCODE\") VALUES (?, ?)"), "Sarf Efrica", "27") + } + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id) VALUES (?, ?)"), "Peter", "4444") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Joe", "1", "4444") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Martin", "2", "4444") + tx.Commit() +} + +// Test a new backwards compatible feature, that missing scan destinations +// will silently scan into sql.RawText rather than failing/panicing +func TestMissingNamesContextContext(t *testing.T) { + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + type PersonPlus struct { + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Email string + //AddedAt time.Time `db:"added_at"` + } + + // test Select first + pps := []PersonPlus{} + // pps lacks added_at destination + err := db.SelectContext(ctx, &pps, "SELECT * FROM person") + if err == nil { + t.Error("Expected missing name from Select to fail, but it did not.") + } + + // test Get + pp := PersonPlus{} + err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") + if err == nil { + t.Error("Expected missing name Get to fail, but it did not.") + } + + // test naked StructScan + pps = []PersonPlus{} + rows, err := db.QueryContext(ctx, "SELECT * FROM person LIMIT 1") + if err != nil { + t.Fatal(err) + } + rows.Next() + err = StructScan(rows, &pps) + if err == nil { + t.Error("Expected missing name in StructScan to fail, but it did not.") + } + rows.Close() + + // now try various things with unsafe set. + db = db.Unsafe() + pps = []PersonPlus{} + err = db.SelectContext(ctx, &pps, "SELECT * FROM person") + if err != nil { + t.Error(err) + } + + // test Get + pp = PersonPlus{} + err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") + if err != nil { + t.Error(err) + } + + // test naked StructScan + pps = []PersonPlus{} + rowsx, err := db.QueryxContext(ctx, "SELECT * FROM person LIMIT 1") + if err != nil { + t.Fatal(err) + } + rowsx.Next() + err = StructScan(rowsx, &pps) + if err != nil { + t.Error(err) + } + rowsx.Close() + + // test Named stmt + if !isUnsafe(db) { + t.Error("Expected db to be unsafe, but it isn't") + } + nstmt, err := db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) + if err != nil { + t.Fatal(err) + } + // its internal stmt should be marked unsafe + if !nstmt.Stmt.unsafe { + t.Error("expected NamedStmt to be unsafe but its underlying stmt did not inherit safety") + } + pps = []PersonPlus{} + err = nstmt.SelectContext(ctx, &pps, map[string]interface{}{"name": "Jason"}) + if err != nil { + t.Fatal(err) + } + if len(pps) != 1 { + t.Errorf("Expected 1 person back, got %d", len(pps)) + } + + // test it with a safe db + db.unsafe = false + if isUnsafe(db) { + t.Error("expected db to be safe but it isn't") + } + nstmt, err = db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) + if err != nil { + t.Fatal(err) + } + // it should be safe + if isUnsafe(nstmt) { + t.Error("NamedStmt did not inherit safety") + } + nstmt.Unsafe() + if !isUnsafe(nstmt) { + t.Error("expected newly unsafed NamedStmt to be unsafe") + } + pps = []PersonPlus{} + err = nstmt.SelectContext(ctx, &pps, map[string]interface{}{"name": "Jason"}) + if err != nil { + t.Fatal(err) + } + if len(pps) != 1 { + t.Errorf("Expected 1 person back, got %d", len(pps)) + } + + }) +} + +func TestEmbeddedStructsContextContext(t *testing.T) { + type Loop1 struct{ Person } + type Loop2 struct{ Loop1 } + type Loop3 struct{ Loop2 } + + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + peopleAndPlaces := []PersonPlace{} + err := db.SelectContext( + ctx, + &peopleAndPlaces, + `SELECT person.*, place.* FROM + person natural join place`) + if err != nil { + t.Fatal(err) + } + for _, pp := range peopleAndPlaces { + if len(pp.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + if len(pp.Place.Country) == 0 { + t.Errorf("Expected non zero lengthed country.") + } + } + + // test embedded structs with StructScan + rows, err := db.QueryxContext( + ctx, + `SELECT person.*, place.* FROM + person natural join place`) + if err != nil { + t.Error(err) + } + + perp := PersonPlace{} + rows.Next() + err = rows.StructScan(&perp) + if err != nil { + t.Error(err) + } + + if len(perp.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + if len(perp.Place.Country) == 0 { + t.Errorf("Expected non zero lengthed country.") + } + + rows.Close() + + // test the same for embedded pointer structs + peopleAndPlacesPtrs := []PersonPlacePtr{} + err = db.SelectContext( + ctx, + &peopleAndPlacesPtrs, + `SELECT person.*, place.* FROM + person natural join place`) + if err != nil { + t.Fatal(err) + } + for _, pp := range peopleAndPlacesPtrs { + if len(pp.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + if len(pp.Place.Country) == 0 { + t.Errorf("Expected non zero lengthed country.") + } + } + + // test "deep nesting" + l3s := []Loop3{} + err = db.SelectContext(ctx, &l3s, `select * from person`) + if err != nil { + t.Fatal(err) + } + for _, l3 := range l3s { + if len(l3.Loop2.Loop1.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + } + + // test "embed conflicts" + ec := []EmbedConflict{} + err = db.SelectContext(ctx, &ec, `select * from person`) + // I'm torn between erroring here or having some kind of working behavior + // in order to allow for more flexibility in destination structs + if err != nil { + t.Errorf("Was not expecting an error on embed conflicts.") + } + }) +} + +func TestJoinQueryContext(t *testing.T) { + type Employee struct { + Name string + ID int64 + // BossID is an id into the employee table + BossID sql.NullInt64 `db:"boss_id"` + } + type Boss Employee + + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + + var employees []struct { + Employee + Boss `db:"boss"` + } + + err := db.SelectContext(ctx, + &employees, + `SELECT employees.*, boss.id "boss.id", boss.name "boss.name" FROM employees + JOIN employees AS boss ON employees.boss_id = boss.id`) + if err != nil { + t.Fatal(err) + } + + for _, em := range employees { + if len(em.Employee.Name) == 0 { + t.Errorf("Expected non zero lengthed name.") + } + if em.Employee.BossID.Int64 != em.Boss.ID { + t.Errorf("Expected boss ids to match") + } + } + }) +} + +func TestJoinQueryNamedPointerStructsContext(t *testing.T) { + type Employee struct { + Name string + ID int64 + // BossID is an id into the employee table + BossID sql.NullInt64 `db:"boss_id"` + } + type Boss Employee + + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + + var employees []struct { + Emp1 *Employee `db:"emp1"` + Emp2 *Employee `db:"emp2"` + *Boss `db:"boss"` + } + + err := db.SelectContext(ctx, + &employees, + `SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id", + emp.name "emp2.name", emp.id "emp2.id", emp.boss_id "emp2.boss_id", + boss.id "boss.id", boss.name "boss.name" FROM employees AS emp + JOIN employees AS boss ON emp.boss_id = boss.id + `) + if err != nil { + t.Fatal(err) + } + + for _, em := range employees { + if len(em.Emp1.Name) == 0 || len(em.Emp2.Name) == 0 { + t.Errorf("Expected non zero lengthed name.") + } + if em.Emp1.BossID.Int64 != em.Boss.ID || em.Emp2.BossID.Int64 != em.Boss.ID { + t.Errorf("Expected boss ids to match") + } + } + }) +} + +func TestSelectSliceMapTimeContext(t *testing.T) { + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + rows, err := db.QueryxContext(ctx, "SELECT * FROM person") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + _, err := rows.SliceScan() + if err != nil { + t.Error(err) + } + } + + rows, err = db.QueryxContext(ctx, "SELECT * FROM person") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + m := map[string]interface{}{} + err := rows.MapScan(m) + if err != nil { + t.Error(err) + } + } + + }) +} + +func TestNilReceiverContext(t *testing.T) { + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + var p *Person + err := db.GetContext(ctx, p, "SELECT * FROM person LIMIT 1") + if err == nil { + t.Error("Expected error when getting into nil struct ptr.") + } + var pp *[]Person + err = db.SelectContext(ctx, pp, "SELECT * FROM person") + if err == nil { + t.Error("Expected an error when selecting into nil slice ptr.") + } + }) +} + +func TestNamedQueryContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE place ( + id integer PRIMARY KEY, + name text NULL + ); + CREATE TABLE person ( + first_name text NULL, + last_name text NULL, + email text NULL + ); + CREATE TABLE placeperson ( + first_name text NULL, + last_name text NULL, + email text NULL, + place_id integer NULL + ); + CREATE TABLE jsperson ( + "FIRST" text NULL, + last_name text NULL, + "EMAIL" text NULL + );`, + drop: ` + drop table person; + drop table jsperson; + drop table place; + drop table placeperson; + `, + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + type Person struct { + FirstName sql.NullString `db:"first_name"` + LastName sql.NullString `db:"last_name"` + Email sql.NullString + } + + p := Person{ + FirstName: sql.NullString{String: "ben", Valid: true}, + LastName: sql.NullString{String: "doe", Valid: true}, + Email: sql.NullString{String: "ben@doe.com", Valid: true}, + } + + q1 := `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)` + _, err := db.NamedExecContext(ctx, q1, p) + if err != nil { + log.Fatal(err) + } + + p2 := &Person{} + rows, err := db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", p) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(p2) + if err != nil { + t.Error(err) + } + if p2.FirstName.String != "ben" { + t.Error("Expected first name of `ben`, got " + p2.FirstName.String) + } + if p2.LastName.String != "doe" { + t.Error("Expected first name of `doe`, got " + p2.LastName.String) + } + } + + // these are tests for #73; they verify that named queries work if you've + // changed the db mapper. This code checks both NamedQuery "ad-hoc" style + // queries and NamedStmt queries, which use different code paths internally. + old := *db.Mapper + + type JSONPerson struct { + FirstName sql.NullString `json:"FIRST"` + LastName sql.NullString `json:"last_name"` + Email sql.NullString + } + + jp := JSONPerson{ + FirstName: sql.NullString{String: "ben", Valid: true}, + LastName: sql.NullString{String: "smith", Valid: true}, + Email: sql.NullString{String: "ben@smith.com", Valid: true}, + } + + db.Mapper = reflectx.NewMapperFunc("json", strings.ToUpper) + + // prepare queries for case sensitivity to test our ToUpper function. + // postgres and sqlite accept "", but mysql uses ``; since Go's multi-line + // strings are `` we use "" by default and swap out for MySQL + pdb := func(s string, db *DB) string { + if db.DriverName() == "mysql" { + return strings.Replace(s, `"`, "`", -1) + } + return s + } + + q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)` + _, err = db.NamedExecContext(ctx, pdb(q1, db), jp) + if err != nil { + t.Fatal(err, db.DriverName()) + } + + // Checks that a person pulled out of the db matches the one we put in + check := func(t *testing.T, rows *Rows) { + jp = JSONPerson{} + for rows.Next() { + err = rows.StructScan(&jp) + if err != nil { + t.Error(err) + } + if jp.FirstName.String != "ben" { + t.Errorf("Expected first name of `ben`, got `%s` (%s) ", jp.FirstName.String, db.DriverName()) + } + if jp.LastName.String != "smith" { + t.Errorf("Expected LastName of `smith`, got `%s` (%s)", jp.LastName.String, db.DriverName()) + } + if jp.Email.String != "ben@smith.com" { + t.Errorf("Expected first name of `doe`, got `%s` (%s)", jp.Email.String, db.DriverName()) + } + } + } + + ns, err := db.PrepareNamed(pdb(` + SELECT * FROM jsperson + WHERE + "FIRST"=:FIRST AND + last_name=:last_name AND + "EMAIL"=:EMAIL + `, db)) + + if err != nil { + t.Fatal(err) + } + rows, err = ns.QueryxContext(ctx, jp) + if err != nil { + t.Fatal(err) + } + + check(t, rows) + + // Check exactly the same thing, but with db.NamedQuery, which does not go + // through the PrepareNamed/NamedStmt path. + rows, err = db.NamedQueryContext(ctx, pdb(` + SELECT * FROM jsperson + WHERE + "FIRST"=:FIRST AND + last_name=:last_name AND + "EMAIL"=:EMAIL + `, db), jp) + if err != nil { + t.Fatal(err) + } + + check(t, rows) + + db.Mapper = &old + + // Test nested structs + type Place struct { + ID int `db:"id"` + Name sql.NullString `db:"name"` + } + type PlacePerson struct { + FirstName sql.NullString `db:"first_name"` + LastName sql.NullString `db:"last_name"` + Email sql.NullString + Place Place `db:"place"` + } + + pl := Place{ + Name: sql.NullString{String: "myplace", Valid: true}, + } + + pp := PlacePerson{ + FirstName: sql.NullString{String: "ben", Valid: true}, + LastName: sql.NullString{String: "doe", Valid: true}, + Email: sql.NullString{String: "ben@doe.com", Valid: true}, + } + + q2 := `INSERT INTO place (id, name) VALUES (1, :name)` + _, err = db.NamedExecContext(ctx, q2, pl) + if err != nil { + log.Fatal(err) + } + + id := 1 + pp.Place.ID = id + + q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` + _, err = db.NamedExecContext(ctx, q3, pp) + if err != nil { + log.Fatal(err) + } + + pp2 := &PlacePerson{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + first_name, + last_name, + email, + place.id AS "place.id", + place.name AS "place.name" + FROM placeperson + INNER JOIN place ON place.id = placeperson.place_id + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp2) + if err != nil { + t.Error(err) + } + if pp2.FirstName.String != "ben" { + t.Error("Expected first name of `ben`, got " + pp2.FirstName.String) + } + if pp2.LastName.String != "doe" { + t.Error("Expected first name of `doe`, got " + pp2.LastName.String) + } + if pp2.Place.Name.String != "myplace" { + t.Error("Expected place name of `myplace`, got " + pp2.Place.Name.String) + } + if pp2.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) + } + } + }) +} + +func TestNilInsertsContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE tt ( + id integer, + value text NULL DEFAULT NULL + );`, + drop: "drop table tt;", + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + type TT struct { + ID int + Value *string + } + var v, v2 TT + r := db.Rebind + + db.MustExecContext(ctx, r(`INSERT INTO tt (id) VALUES (1)`)) + db.GetContext(ctx, &v, r(`SELECT * FROM tt`)) + if v.ID != 1 { + t.Errorf("Expecting id of 1, got %v", v.ID) + } + if v.Value != nil { + t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) + } + + v.ID = 2 + // NOTE: this incidentally uncovered a bug which was that named queries with + // pointer destinations would not work if the passed value here was not addressable, + // as reflectx.FieldByIndexes attempts to allocate nil pointer receivers for + // writing. This was fixed by creating & using the reflectx.FieldByIndexesReadOnly + // function. This next line is important as it provides the only coverage for this. + db.NamedExecContext(ctx, `INSERT INTO tt (id, value) VALUES (:id, :value)`, v) + + db.GetContext(ctx, &v2, r(`SELECT * FROM tt WHERE id=2`)) + if v.ID != v2.ID { + t.Errorf("%v != %v", v.ID, v2.ID) + } + if v2.Value != nil { + t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) + } + }) +} + +func TestScanErrorContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE kv ( + k text, + v integer + );`, + drop: `drop table kv;`, + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + type WrongTypes struct { + K int + V string + } + _, err := db.Exec(db.Rebind("INSERT INTO kv (k, v) VALUES (?, ?)"), "hi", 1) + if err != nil { + t.Error(err) + } + + rows, err := db.QueryxContext(ctx, "SELECT * FROM kv") + if err != nil { + t.Error(err) + } + for rows.Next() { + var wt WrongTypes + err := rows.StructScan(&wt) + if err == nil { + t.Errorf("%s: Scanning wrong types into keys should have errored.", db.DriverName()) + } + } + }) +} + +// FIXME: this function is kinda big but it slows things down to be constantly +// loading and reloading the schema.. + +func TestUsageContext(t *testing.T) { + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + slicemembers := []SliceMember{} + err := db.SelectContext(ctx, &slicemembers, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + + people := []Person{} + + err = db.SelectContext(ctx, &people, "SELECT * FROM person ORDER BY first_name ASC") + if err != nil { + t.Fatal(err) + } + + jason, john := people[0], people[1] + if jason.FirstName != "Jason" { + t.Errorf("Expecting FirstName of Jason, got %s", jason.FirstName) + } + if jason.LastName != "Moiron" { + t.Errorf("Expecting LastName of Moiron, got %s", jason.LastName) + } + if jason.Email != "jmoiron@jmoiron.net" { + t.Errorf("Expecting Email of jmoiron@jmoiron.net, got %s", jason.Email) + } + if john.FirstName != "John" || john.LastName != "Doe" || john.Email != "johndoeDNE@gmail.net" { + t.Errorf("John Doe's person record not what expected: Got %v\n", john) + } + + jason = Person{} + err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Jason") + + if err != nil { + t.Fatal(err) + } + if jason.FirstName != "Jason" { + t.Errorf("Expecting to get back Jason, but got %v\n", jason.FirstName) + } + + err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Foobar") + if err == nil { + t.Errorf("Expecting an error, got nil\n") + } + if err != sql.ErrNoRows { + t.Errorf("Expected sql.ErrNoRows, got %v\n", err) + } + + // The following tests check statement reuse, which was actually a problem + // due to copying being done when creating Stmt's which was eventually removed + stmt1, err := db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Fatal(err) + } + jason = Person{} + + row := stmt1.QueryRowx("DoesNotExist") + row.Scan(&jason) + row = stmt1.QueryRowx("DoesNotExist") + row.Scan(&jason) + + err = stmt1.GetContext(ctx, &jason, "DoesNotExist User") + if err == nil { + t.Error("Expected an error") + } + err = stmt1.GetContext(ctx, &jason, "DoesNotExist User 2") + if err == nil { + t.Fatal(err) + } + + stmt2, err := db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Fatal(err) + } + jason = Person{} + tx, err := db.Beginx() + if err != nil { + t.Fatal(err) + } + tstmt2 := tx.Stmtx(stmt2) + row2 := tstmt2.QueryRowx("Jason") + err = row2.StructScan(&jason) + if err != nil { + t.Error(err) + } + tx.Commit() + + places := []*Place{} + err = db.SelectContext(ctx, &places, "SELECT telcode FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + + usa, singsing, honkers := places[0], places[1], places[2] + + if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { + t.Errorf("Expected integer telcodes to work, got %#v", places) + } + + placesptr := []PlacePtr{} + err = db.SelectContext(ctx, &placesptr, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Error(err) + } + //fmt.Printf("%#v\n%#v\n%#v\n", placesptr[0], placesptr[1], placesptr[2]) + + // if you have null fields and use SELECT *, you must use sql.Null* in your struct + // this test also verifies that you can use either a []Struct{} or a []*Struct{} + places2 := []Place{} + err = db.SelectContext(ctx, &places2, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + + usa, singsing, honkers = &places2[0], &places2[1], &places2[2] + + // this should return a type error that &p is not a pointer to a struct slice + p := Place{} + err = db.SelectContext(ctx, &p, "SELECT * FROM place ORDER BY telcode ASC") + if err == nil { + t.Errorf("Expected an error, argument to select should be a pointer to a struct slice") + } + + // this should be an error + pl := []Place{} + err = db.SelectContext(ctx, pl, "SELECT * FROM place ORDER BY telcode ASC") + if err == nil { + t.Errorf("Expected an error, argument to select should be a pointer to a struct slice, not a slice.") + } + + if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { + t.Errorf("Expected integer telcodes to work, got %#v", places) + } + + stmt, err := db.PreparexContext(ctx, db.Rebind("SELECT country, telcode FROM place WHERE telcode > ? ORDER BY telcode ASC")) + if err != nil { + t.Error(err) + } + + places = []*Place{} + err = stmt.SelectContext(ctx, &places, 10) + if len(places) != 2 { + t.Error("Expected 2 places, got 0.") + } + if err != nil { + t.Fatal(err) + } + singsing, honkers = places[0], places[1] + if singsing.TelCode != 65 || honkers.TelCode != 852 { + t.Errorf("Expected the right telcodes, got %#v", places) + } + + rows, err := db.QueryxContext(ctx, "SELECT * FROM place") + if err != nil { + t.Fatal(err) + } + place := Place{} + for rows.Next() { + err = rows.StructScan(&place) + if err != nil { + t.Fatal(err) + } + } + + rows, err = db.QueryxContext(ctx, "SELECT * FROM place") + if err != nil { + t.Fatal(err) + } + m := map[string]interface{}{} + for rows.Next() { + err = rows.MapScan(m) + if err != nil { + t.Fatal(err) + } + _, ok := m["country"] + if !ok { + t.Errorf("Expected key `country` in map but could not find it (%#v)\n", m) + } + } + + rows, err = db.QueryxContext(ctx, "SELECT * FROM place") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + s, err := rows.SliceScan() + if err != nil { + t.Error(err) + } + if len(s) != 3 { + t.Errorf("Expected 3 columns in result, got %d\n", len(s)) + } + } + + // test advanced querying + // test that NamedExec works with a map as well as a struct + _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first, :last, :email)", map[string]interface{}{ + "first": "Bin", + "last": "Smuth", + "email": "bensmith@allblacks.nz", + }) + if err != nil { + t.Fatal(err) + } + + // ensure that if the named param happens right at the end it still works + // ensure that NamedQuery works with a map[string]interface{} + rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first", map[string]interface{}{"first": "Bin"}) + if err != nil { + t.Fatal(err) + } + + ben := &Person{} + for rows.Next() { + err = rows.StructScan(ben) + if err != nil { + t.Fatal(err) + } + if ben.FirstName != "Bin" { + t.Fatal("Expected first name of `Bin`, got " + ben.FirstName) + } + if ben.LastName != "Smuth" { + t.Fatal("Expected first name of `Smuth`, got " + ben.LastName) + } + } + + ben.FirstName = "Ben" + ben.LastName = "Smith" + ben.Email = "binsmuth@allblacks.nz" + + // Insert via a named query using the struct + _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", ben) + + if err != nil { + t.Fatal(err) + } + + rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", ben) + if err != nil { + t.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(ben) + if err != nil { + t.Fatal(err) + } + if ben.FirstName != "Ben" { + t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) + } + if ben.LastName != "Smith" { + t.Fatal("Expected first name of `Smith`, got " + ben.LastName) + } + } + // ensure that Get does not panic on emppty result set + person := &Person{} + err = db.GetContext(ctx, person, "SELECT * FROM person WHERE first_name=$1", "does-not-exist") + if err == nil { + t.Fatal("Should have got an error for Get on non-existant row.") + } + + // lets test prepared statements some more + + stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Fatal(err) + } + rows, err = stmt.QueryxContext(ctx, "Ben") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(ben) + if err != nil { + t.Fatal(err) + } + if ben.FirstName != "Ben" { + t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) + } + if ben.LastName != "Smith" { + t.Fatal("Expected first name of `Smith`, got " + ben.LastName) + } + } + + john = Person{} + stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Error(err) + } + err = stmt.GetContext(ctx, &john, "John") + if err != nil { + t.Error(err) + } + + // test name mapping + // THIS USED TO WORK BUT WILL NO LONGER WORK. + db.MapperFunc(strings.ToUpper) + rsa := CPlace{} + err = db.GetContext(ctx, &rsa, "SELECT * FROM capplace;") + if err != nil { + t.Error(err, "in db:", db.DriverName()) + } + db.MapperFunc(strings.ToLower) + + // create a copy and change the mapper, then verify the copy behaves + // differently from the original. + dbCopy := NewDb(db.DB, db.DriverName()) + dbCopy.MapperFunc(strings.ToUpper) + err = dbCopy.GetContext(ctx, &rsa, "SELECT * FROM capplace;") + if err != nil { + fmt.Println(db.DriverName()) + t.Error(err) + } + + err = db.GetContext(ctx, &rsa, "SELECT * FROM cappplace;") + if err == nil { + t.Error("Expected no error, got ", err) + } + + // test base type slices + var sdest []string + rows, err = db.QueryxContext(ctx, "SELECT email FROM person ORDER BY email ASC;") + if err != nil { + t.Error(err) + } + err = scanAll(rows, &sdest, false) + if err != nil { + t.Error(err) + } + + // test Get with base types + var count int + err = db.GetContext(ctx, &count, "SELECT count(*) FROM person;") + if err != nil { + t.Error(err) + } + if count != len(sdest) { + t.Errorf("Expected %d == %d (count(*) vs len(SELECT ..)", count, len(sdest)) + } + + // test Get and Select with time.Time, #84 + var addedAt time.Time + err = db.GetContext(ctx, &addedAt, "SELECT added_at FROM person LIMIT 1;") + if err != nil { + t.Error(err) + } + + var addedAts []time.Time + err = db.SelectContext(ctx, &addedAts, "SELECT added_at FROM person;") + if err != nil { + t.Error(err) + } + + // test it on a double pointer + var pcount *int + err = db.GetContext(ctx, &pcount, "SELECT count(*) FROM person;") + if err != nil { + t.Error(err) + } + if *pcount != count { + t.Errorf("expected %d = %d", *pcount, count) + } + + // test Select... + sdest = []string{} + err = db.SelectContext(ctx, &sdest, "SELECT first_name FROM person ORDER BY first_name ASC;") + if err != nil { + t.Error(err) + } + expected := []string{"Ben", "Bin", "Jason", "John"} + for i, got := range sdest { + if got != expected[i] { + t.Errorf("Expected %d result to be %s, but got %s", i, expected[i], got) + } + } + + var nsdest []sql.NullString + err = db.SelectContext(ctx, &nsdest, "SELECT city FROM place ORDER BY city ASC") + if err != nil { + t.Error(err) + } + for _, val := range nsdest { + if val.Valid && val.String != "New York" { + t.Errorf("expected single valid result to be `New York`, but got %s", val.String) + } + } + }) +} + +// tests that sqlx will not panic when the wrong driver is passed because +// of an automatic nil dereference in sqlx.Open(), which was fixed. +func TestDoNotPanicOnConnectContext(t *testing.T) { + _, err := ConnectContext(context.Background(), "bogus", "hehe") + if err == nil { + t.Errorf("Should return error when using bogus driverName") + } +} + +func TestEmbeddedMapsContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE message ( + string text, + properties text + );`, + drop: `drop table message;`, + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + messages := []Message{ + {"Hello, World", PropertyMap{"one": "1", "two": "2"}}, + {"Thanks, Joy", PropertyMap{"pull": "request"}}, + } + q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties);` + for _, m := range messages { + _, err := db.NamedExecContext(ctx, q1, m) + if err != nil { + t.Fatal(err) + } + } + var count int + err := db.GetContext(ctx, &count, "SELECT count(*) FROM message") + if err != nil { + t.Fatal(err) + } + if count != len(messages) { + t.Fatalf("Expected %d messages in DB, found %d", len(messages), count) + } + + var m Message + err = db.GetContext(ctx, &m, "SELECT * FROM message LIMIT 1;") + if err != nil { + t.Fatal(err) + } + if m.Properties == nil { + t.Fatal("Expected m.Properties to not be nil, but it was.") + } + }) +} + +func TestIssue197Context(t *testing.T) { + // this test actually tests for a bug in database/sql: + // https://github.com/golang/go/issues/13905 + // this potentially makes _any_ named type that is an alias for []byte + // unsafe to use in a lot of different ways (basically, unsafe to hold + // onto after loading from the database). + t.Skip() + + type mybyte []byte + type Var struct{ Raw json.RawMessage } + type Var2 struct{ Raw []byte } + type Var3 struct{ Raw mybyte } + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + var err error + var v, q Var + if err = db.GetContext(ctx, &v, `SELECT '{"a": "b"}' AS raw`); err != nil { + t.Fatal(err) + } + if err = db.GetContext(ctx, &q, `SELECT 'null' AS raw`); err != nil { + t.Fatal(err) + } + + var v2, q2 Var2 + if err = db.GetContext(ctx, &v2, `SELECT '{"a": "b"}' AS raw`); err != nil { + t.Fatal(err) + } + if err = db.GetContext(ctx, &q2, `SELECT 'null' AS raw`); err != nil { + t.Fatal(err) + } + + var v3, q3 Var3 + if err = db.QueryRowContext(ctx, `SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil { + t.Fatal(err) + } + if err = db.QueryRowContext(ctx, `SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil { + t.Fatal(err) + } + t.Fail() + }) +} + +func TestInContext(t *testing.T) { + // some quite normal situations + type tr struct { + q string + args []interface{} + c int + } + tests := []tr{ + {"SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?", + []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}, + 7}, + {"SELECT * FROM foo WHERE x in (?)", + []interface{}{[]int{1, 2, 3, 4, 5, 6, 7, 8}}, + 8}, + } + for _, test := range tests { + q, a, err := In(test.q, test.args...) + if err != nil { + t.Error(err) + } + if len(a) != test.c { + t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a) + } + if strings.Count(q, "?") != test.c { + t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?")) + } + } + + // too many bindVars, but no slices, so short circuits parsing + // i'm not sure if this is the right behavior; this query/arg combo + // might not work, but we shouldn't parse if we don't need to + { + orig := "SELECT * FROM foo WHERE x = ? AND y = ?" + q, a, err := In(orig, "foo", "bar", "baz") + if err != nil { + t.Error(err) + } + if len(a) != 3 { + t.Errorf("Expected 3 args, but got %d (%+v)", len(a), a) + } + if q != orig { + t.Error("Expected unchanged query.") + } + } + + tests = []tr{ + // too many bindvars; slice present so should return error during parse + {"SELECT * FROM foo WHERE x = ? and y = ?", + []interface{}{"foo", []int{1, 2, 3}, "bar"}, + 0}, + // empty slice, should return error before parse + {"SELECT * FROM foo WHERE x = ?", + []interface{}{[]int{}}, + 0}, + // too *few* bindvars, should return an error + {"SELECT * FROM foo WHERE x = ? AND y in (?)", + []interface{}{[]int{1, 2, 3}}, + 0}, + } + for _, test := range tests { + _, _, err := In(test.q, test.args...) + if err == nil { + t.Error("Expected an error, but got nil.") + } + } + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + //tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") + //tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") + //tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") + telcodes := []int{852, 65} + q := "SELECT * FROM place WHERE telcode IN(?) ORDER BY telcode" + query, args, err := In(q, telcodes) + if err != nil { + t.Error(err) + } + query = db.Rebind(query) + places := []Place{} + err = db.SelectContext(ctx, &places, query, args...) + if err != nil { + t.Error(err) + } + if len(places) != 2 { + t.Fatalf("Expecting 2 results, got %d", len(places)) + } + if places[0].TelCode != 65 { + t.Errorf("Expecting singapore first, but got %#v", places[0]) + } + if places[1].TelCode != 852 { + t.Errorf("Expecting hong kong second, but got %#v", places[1]) + } + }) +} + +func TestEmbeddedLiteralsContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE x ( + k text + );`, + drop: `drop table x;`, + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + type t1 struct { + K *string + } + type t2 struct { + Inline struct { + F string + } + K *string + } + + db.MustExecContext(ctx, db.Rebind("INSERT INTO x (k) VALUES (?), (?), (?);"), "one", "two", "three") + + target := t1{} + err := db.GetContext(ctx, &target, db.Rebind("SELECT * FROM x WHERE k=?"), "one") + if err != nil { + t.Error(err) + } + if *target.K != "one" { + t.Error("Expected target.K to be `one`, got ", target.K) + } + + target2 := t2{} + err = db.GetContext(ctx, &target2, db.Rebind("SELECT * FROM x WHERE k=?"), "one") + if err != nil { + t.Error(err) + } + if *target2.K != "one" { + t.Errorf("Expected target2.K to be `one`, got `%v`", target2.K) + } + }) +} diff --git a/vendor/github.com/jmoiron/sqlx/sqlx_test.go b/vendor/github.com/jmoiron/sqlx/sqlx_test.go index 6fa725d5..5752773a 100644 --- a/vendor/github.com/jmoiron/sqlx/sqlx_test.go +++ b/vendor/github.com/jmoiron/sqlx/sqlx_test.go @@ -591,11 +591,21 @@ func TestNilReceiver(t *testing.T) { func TestNamedQuery(t *testing.T) { var schema = Schema{ create: ` + CREATE TABLE place ( + id integer PRIMARY KEY, + name text NULL + ); CREATE TABLE person ( first_name text NULL, last_name text NULL, email text NULL ); + CREATE TABLE placeperson ( + first_name text NULL, + last_name text NULL, + email text NULL, + place_id integer NULL + ); CREATE TABLE jsperson ( "FIRST" text NULL, last_name text NULL, @@ -604,6 +614,8 @@ func TestNamedQuery(t *testing.T) { drop: ` drop table person; drop table jsperson; + drop table place; + drop table placeperson; `, } @@ -734,6 +746,76 @@ func TestNamedQuery(t *testing.T) { db.Mapper = &old + // Test nested structs + type Place struct { + ID int `db:"id"` + Name sql.NullString `db:"name"` + } + type PlacePerson struct { + FirstName sql.NullString `db:"first_name"` + LastName sql.NullString `db:"last_name"` + Email sql.NullString + Place Place `db:"place"` + } + + pl := Place{ + Name: sql.NullString{String: "myplace", Valid: true}, + } + + pp := PlacePerson{ + FirstName: sql.NullString{String: "ben", Valid: true}, + LastName: sql.NullString{String: "doe", Valid: true}, + Email: sql.NullString{String: "ben@doe.com", Valid: true}, + } + + q2 := `INSERT INTO place (id, name) VALUES (1, :name)` + _, err = db.NamedExec(q2, pl) + if err != nil { + log.Fatal(err) + } + + id := 1 + pp.Place.ID = id + + q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` + _, err = db.NamedExec(q3, pp) + if err != nil { + log.Fatal(err) + } + + pp2 := &PlacePerson{} + rows, err = db.NamedQuery(` + SELECT + first_name, + last_name, + email, + place.id AS "place.id", + place.name AS "place.name" + FROM placeperson + INNER JOIN place ON place.id = placeperson.place_id + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp2) + if err != nil { + t.Error(err) + } + if pp2.FirstName.String != "ben" { + t.Error("Expected first name of `ben`, got " + pp2.FirstName.String) + } + if pp2.LastName.String != "doe" { + t.Error("Expected first name of `doe`, got " + pp2.LastName.String) + } + if pp2.Place.Name.String != "myplace" { + t.Error("Expected place name of `myplace`, got " + pp2.Place.Name.String) + } + if pp2.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) + } + } }) } @@ -885,6 +967,9 @@ func TestUsage(t *testing.T) { t.Error("Expected an error") } err = stmt1.Get(&jason, "DoesNotExist User 2") + if err == nil { + t.Fatal(err) + } stmt2, err := db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) if err != nil { @@ -905,6 +990,10 @@ func TestUsage(t *testing.T) { places := []*Place{} err = db.Select(&places, "SELECT telcode FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + usa, singsing, honkers := places[0], places[1], places[2] if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { @@ -922,6 +1011,10 @@ func TestUsage(t *testing.T) { // this test also verifies that you can use either a []Struct{} or a []*Struct{} places2 := []Place{} err = db.Select(&places2, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + usa, singsing, honkers = &places2[0], &places2[1], &places2[2] // this should return a type error that &p is not a pointer to a struct slice @@ -1276,8 +1369,9 @@ func TestBindMap(t *testing.T) { type Message struct { Text string `db:"string"` - Properties PropertyMap // Stored as JSON in the database + Properties PropertyMap `db:"properties"` // Stored as JSON in the database } + type PropertyMap map[string]string // Implement driver.Valuer and sql.Scanner interfaces on PropertyMap @@ -1314,7 +1408,7 @@ func TestEmbeddedMaps(t *testing.T) { {"Hello, World", PropertyMap{"one": "1", "two": "2"}}, {"Thanks, Joy", PropertyMap{"pull": "request"}}, } - q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties)` + q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties);` for _, m := range messages { _, err := db.NamedExec(q1, m) if err != nil { @@ -1324,19 +1418,19 @@ func TestEmbeddedMaps(t *testing.T) { var count int err := db.Get(&count, "SELECT count(*) FROM message") if err != nil { - t.Error(err) + t.Fatal(err) } if count != len(messages) { - t.Errorf("Expected %d messages in DB, found %d", len(messages), count) + t.Fatalf("Expected %d messages in DB, found %d", len(messages), count) } var m Message - err = db.Get(&m, "SELECT * FROM message LIMIT 1") + err = db.Get(&m, "SELECT * FROM message LIMIT 1;") if err != nil { - t.Error(err) + t.Fatal(err) } if m.Properties == nil { - t.Error("Expected m.Properties to not be nil, but it was.") + t.Fatal("Expected m.Properties to not be nil, but it was.") } }) } @@ -1359,31 +1453,25 @@ func TestIssue197(t *testing.T) { if err = db.Get(&v, `SELECT '{"a": "b"}' AS raw`); err != nil { t.Fatal(err) } - fmt.Printf("%s: v %s\n", db.DriverName(), v.Raw) if err = db.Get(&q, `SELECT 'null' AS raw`); err != nil { t.Fatal(err) } - fmt.Printf("%s: v %s\n", db.DriverName(), v.Raw) var v2, q2 Var2 if err = db.Get(&v2, `SELECT '{"a": "b"}' AS raw`); err != nil { t.Fatal(err) } - fmt.Printf("%s: v2 %s\n", db.DriverName(), v2.Raw) if err = db.Get(&q2, `SELECT 'null' AS raw`); err != nil { t.Fatal(err) } - fmt.Printf("%s: v2 %s\n", db.DriverName(), v2.Raw) var v3, q3 Var3 if err = db.QueryRow(`SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil { t.Fatal(err) } - fmt.Printf("v3 %s\n", v3.Raw) if err = db.QueryRow(`SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil { t.Fatal(err) } - fmt.Printf("v3 %s\n", v3.Raw) t.Fail() }) } @@ -1649,6 +1737,36 @@ func BenchmarkIn(b *testing.B) { } } +func BenchmarkIn1k(b *testing.B) { + q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` + + var vals [1000]interface{} + + for i := 0; i < b.N; i++ { + _, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...) + } +} + +func BenchmarkIn1kInt(b *testing.B) { + q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` + + var vals [1000]int + + for i := 0; i < b.N; i++ { + _, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...) + } +} + +func BenchmarkIn1kString(b *testing.B) { + q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` + + var vals [1000]string + + for i := 0; i < b.N; i++ { + _, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...) + } +} + func BenchmarkRebind(b *testing.B) { b.StopTimer() q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)` diff --git a/vendor/github.com/jmoiron/sqlx/types/types.go b/vendor/github.com/jmoiron/sqlx/types/types.go index 53848bc0..7b014c1e 100644 --- a/vendor/github.com/jmoiron/sqlx/types/types.go +++ b/vendor/github.com/jmoiron/sqlx/types/types.go @@ -39,6 +39,9 @@ func (g *GzippedText) Scan(src interface{}) error { return errors.New("Incompatible type for GzippedText") } reader, err := gzip.NewReader(bytes.NewReader(source)) + if err != nil { + return err + } defer reader.Close() b, err := ioutil.ReadAll(reader) if err != nil { @@ -54,9 +57,14 @@ func (g *GzippedText) Scan(src interface{}) error { // implements `Unmarshal`, which unmarshals the json within to an interface{} type JSONText json.RawMessage +var emptyJSON = JSONText("{}") + // MarshalJSON returns the *j as the JSON encoding of j. -func (j *JSONText) MarshalJSON() ([]byte, error) { - return *j, nil +func (j JSONText) MarshalJSON() ([]byte, error) { + if len(j) == 0 { + return emptyJSON, nil + } + return j, nil } // UnmarshalJSON sets *j to a copy of data @@ -66,7 +74,6 @@ func (j *JSONText) UnmarshalJSON(data []byte) error { } *j = append((*j)[0:0], data...) return nil - } // Value returns j as a value. This does a validating unmarshal into another @@ -83,11 +90,17 @@ func (j JSONText) Value() (driver.Value, error) { // Scan stores the src in *j. No validation is done. func (j *JSONText) Scan(src interface{}) error { var source []byte - switch src.(type) { + switch t := src.(type) { case string: - source = []byte(src.(string)) + source = []byte(t) case []byte: - source = src.([]byte) + if len(t) == 0 { + source = emptyJSON + } else { + source = t + } + case nil: + *j = emptyJSON default: return errors.New("Incompatible type for JSONText") } @@ -97,10 +110,63 @@ func (j *JSONText) Scan(src interface{}) error { // Unmarshal unmarshal's the json in j to v, as in json.Unmarshal. func (j *JSONText) Unmarshal(v interface{}) error { + if len(*j) == 0 { + *j = emptyJSON + } return json.Unmarshal([]byte(*j), v) } -// Pretty printing for JSONText types +// String supports pretty printing for JSONText types. func (j JSONText) String() string { return string(j) } + +// NullJSONText represents a JSONText that may be null. +// NullJSONText implements the scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullJSONText struct { + JSONText + Valid bool // Valid is true if JSONText is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullJSONText) Scan(value interface{}) error { + if value == nil { + n.JSONText, n.Valid = emptyJSON, false + return nil + } + n.Valid = true + return n.JSONText.Scan(value) +} + +// Value implements the driver Valuer interface. +func (n NullJSONText) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.JSONText.Value() +} + +// BitBool is an implementation of a bool for the MySQL type BIT(1). +// This type allows you to avoid wasting an entire byte for MySQL's boolean type TINYINT. +type BitBool bool + +// Value implements the driver.Valuer interface, +// and turns the BitBool into a bitfield (BIT(1)) for MySQL storage. +func (b BitBool) Value() (driver.Value, error) { + if b { + return []byte{1}, nil + } + return []byte{0}, nil +} + +// Scan implements the sql.Scanner interface, +// and turns the bitfield incoming from MySQL into a BitBool +func (b *BitBool) Scan(src interface{}) error { + v, ok := src.([]byte) + if !ok { + return errors.New("bad []byte type assertion") + } + *b = v[0] == 1 + return nil +} diff --git a/vendor/github.com/jmoiron/sqlx/types/types_test.go b/vendor/github.com/jmoiron/sqlx/types/types_test.go index 78a1ef82..29813d1e 100644 --- a/vendor/github.com/jmoiron/sqlx/types/types_test.go +++ b/vendor/github.com/jmoiron/sqlx/types/types_test.go @@ -39,4 +39,89 @@ func TestJSONText(t *testing.T) { if err == nil { t.Errorf("Was expecting invalid json to fail!") } + + j = JSONText("") + v, err = j.Value() + if err != nil { + t.Errorf("Was not expecting an error") + } + + err = (&j).Scan(v) + if err != nil { + t.Errorf("Was not expecting an error") + } + + j = JSONText(nil) + v, err = j.Value() + if err != nil { + t.Errorf("Was not expecting an error") + } + + err = (&j).Scan(v) + if err != nil { + t.Errorf("Was not expecting an error") + } +} + +func TestNullJSONText(t *testing.T) { + j := NullJSONText{} + err := j.Scan(`{"foo": 1, "bar": 2}`) + if err != nil { + t.Errorf("Was not expecting an error") + } + v, err := j.Value() + if err != nil { + t.Errorf("Was not expecting an error") + } + err = (&j).Scan(v) + if err != nil { + t.Errorf("Was not expecting an error") + } + m := map[string]interface{}{} + j.Unmarshal(&m) + + if m["foo"].(float64) != 1 || m["bar"].(float64) != 2 { + t.Errorf("Expected valid json but got some garbage instead? %#v", m) + } + + j = NullJSONText{} + err = j.Scan(nil) + if err != nil { + t.Errorf("Was not expecting an error") + } + if j.Valid != false { + t.Errorf("Expected valid to be false, but got true") + } +} + +func TestBitBool(t *testing.T) { + // Test true value + var b BitBool = true + + v, err := b.Value() + if err != nil { + t.Errorf("Cannot return error") + } + err = (&b).Scan(v) + if err != nil { + t.Errorf("Was not expecting an error") + } + if !b { + t.Errorf("Was expecting the bool we sent in (true), got %v", b) + } + + // Test false value + b = false + + v, err = b.Value() + if err != nil { + t.Errorf("Cannot return error") + } + err = (&b).Scan(v) + if err != nil { + t.Errorf("Was not expecting an error") + } + if b { + t.Errorf("Was expecting the bool we sent in (false), got %v", b) + } }