diff --git a/modules/git/blob.go b/modules/git/blob.go
index 30615afe32..14ca2b1445 100644
--- a/modules/git/blob.go
+++ b/modules/git/blob.go
@@ -8,6 +8,7 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/base64"
+	"fmt"
 	"io"
 
 	"forgejo.org/modules/log"
@@ -172,33 +173,48 @@ func (b *Blob) GetBlobContent(limit int64) (string, error) {
 	return string(buf), err
 }
 
-// GetBlobContentBase64 Reads the content of the blob with a base64 encode and returns the encoded string
-func (b *Blob) GetBlobContentBase64() (string, error) {
-	dataRc, err := b.DataAsync()
-	if err != nil {
-		return "", err
-	}
-	defer dataRc.Close()
+// BlobTooLargeError is returned by [Blob.GetContentBase64] when a blob is larger than the
+// requested limit. Size is the value reported by Blob.Size and Limit is the limit it exceeded.
+type BlobTooLargeError struct {
+	Size, Limit int64
+}
 
-	pr, pw := io.Pipe()
-	encoder := base64.NewEncoder(base64.StdEncoding, pw)
+// Error implements the error interface, reporting both the blob size and the limit.
+func (e BlobTooLargeError) Error() string {
+	return fmt.Sprintf("blob: content larger than limit (%d > %d)", e.Size, e.Limit)
+}
 
-	go func() {
-		_, err := io.Copy(encoder, dataRc)
-		_ = encoder.Close()
-
-		if err != nil {
-			_ = pw.CloseWithError(err)
-		} else {
-			_ = pw.Close()
+// GetContentBase64 reads the content of the blob and returns it as a base64 encoded string.
+// Returns [BlobTooLargeError] if the (unencoded) content is larger than the limit.
+func (b *Blob) GetContentBase64(limit int64) (string, error) {
+	if size := b.Size(); size > limit {
+		return "", BlobTooLargeError{
+			Size:  size,
+			Limit: limit,
 		}
-	}()
+	}
 
-	out, err := io.ReadAll(pr)
+	rc, size, err := b.NewTruncatedReader(limit)
 	if err != nil {
 		return "", err
 	}
-	return string(out), nil
+	defer rc.Close()
+
+	// Reserve capacity for the base64 form of size bytes up front.
+	encoding := base64.StdEncoding
+	buf := bytes.NewBuffer(make([]byte, 0, encoding.EncodedLen(int(size))))
+
+	encoder := base64.NewEncoder(encoding, buf)
+
+	if _, err := io.Copy(encoder, rc); err != nil {
+		return "", err
+	}
+	// Close flushes any partially written base64 block into buf.
+	if err := encoder.Close(); err != nil {
+		return "", err
+	}
+
+	return buf.String(), nil
 }
 
 // GuessContentType guesses the content type of the blob.
diff --git a/modules/git/blob_test.go b/modules/git/blob_test.go index 54115013d3..a4b8033941 100644 --- a/modules/git/blob_test.go +++ b/modules/git/blob_test.go @@ -63,6 +63,24 @@ func TestBlob(t *testing.T) { require.Equal(t, "file2\n", r) }) + t.Run("GetContentBase64", func(t *testing.T) { + r, err := testBlob.GetContentBase64(100) + require.NoError(t, err) + require.Equal(t, "ZmlsZTIK", r) + + r, err = testBlob.GetContentBase64(-1) + require.ErrorAs(t, err, &BlobTooLargeError{}) + require.Empty(t, r) + + r, err = testBlob.GetContentBase64(4) + require.ErrorAs(t, err, &BlobTooLargeError{}) + require.Empty(t, r) + + r, err = testBlob.GetContentBase64(6) + require.NoError(t, err) + require.Equal(t, "ZmlsZTIK", r) + }) + t.Run("NewTruncatedReader", func(t *testing.T) { // read fewer than available rc, size, err := testBlob.NewTruncatedReader(100) diff --git a/routers/api/v1/repo/wiki.go b/routers/api/v1/repo/wiki.go index bb4cf0f211..7b6a00408a 100644 --- a/routers/api/v1/repo/wiki.go +++ b/routers/api/v1/repo/wiki.go @@ -5,6 +5,7 @@ package repo import ( "encoding/base64" + "errors" "fmt" "net/http" "net/url" @@ -506,11 +507,8 @@ func findWikiRepoCommit(ctx *context.APIContext) (*git.Repository, *git.Commit) // given tree entry, encoded with base64. Writes to ctx if an error occurs. 
func wikiContentsByEntry(ctx *context.APIContext, entry *git.TreeEntry) string { blob := entry.Blob() - if blob.Size() > setting.API.DefaultMaxBlobSize { - return "" - } - content, err := blob.GetBlobContentBase64() - if err != nil { + content, err := blob.GetContentBase64(setting.API.DefaultMaxBlobSize) + if err != nil && !errors.As(err, &git.BlobTooLargeError{}) { ctx.Error(http.StatusInternalServerError, "GetBlobContentBase64", err) return "" } diff --git a/services/repository/files/content.go b/services/repository/files/content.go index 3d2217df18..dfdee1d1df 100644 --- a/services/repository/files/content.go +++ b/services/repository/files/content.go @@ -5,6 +5,7 @@ package files import ( "context" + "errors" "fmt" "net/url" "path" @@ -273,13 +274,11 @@ func GetBlobBySHA(ctx context.Context, repo *repo_model.Repository, gitRepo *git if err != nil { return nil, err } - content := "" - if gitBlob.Size() <= setting.API.DefaultMaxBlobSize { - content, err = gitBlob.GetBlobContentBase64() - if err != nil { - return nil, err - } + content, err := gitBlob.GetContentBase64(setting.API.DefaultMaxBlobSize) + if err != nil && !errors.As(err, &git.BlobTooLargeError{}) { + return nil, err } + return &api.GitBlob{ SHA: gitBlob.ID.String(), URL: repo.APIURL() + "/git/blobs/" + url.PathEscape(gitBlob.ID.String()),