diff --git a/server/handlers.go b/server/handlers.go index 4f1b16ac..e038c205 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -39,6 +39,7 @@ import ( "html" htmlTemplate "html/template" "io" + "log" "mime" "net" "net/http" @@ -399,8 +400,8 @@ func (s *Server) viewHandler(w http.ResponseWriter, r *http.Request) { s.userVoiceKey, purgeTime, maxUploadSize, - token(s.randomTokenLength), - token(s.randomTokenLength), + token(s.randomTokenLength, s.logger), + token(s.randomTokenLength, s.logger), } w.Header().Set("Vary", "Accept") @@ -449,7 +450,7 @@ func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) { return } - token := token(s.randomTokenLength) + token := token(s.randomTokenLength, s.logger) w.Header().Set("Content-Type", "text/plain") @@ -514,7 +515,7 @@ func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) { } } - metadata := metadataForRequest(contentType, contentLength, s.randomTokenLength, r) + metadata := metadataForRequest(contentType, contentLength, s.randomTokenLength, s.logger, r) buffer := &bytes.Buffer{} if err := json.NewEncoder(buffer).Encode(metadata); err != nil { @@ -591,14 +592,14 @@ type metadata struct { DecryptedContentType string } -func metadataForRequest(contentType string, contentLength int64, randomTokenLength int, r *http.Request) metadata { +func metadataForRequest(contentType string, contentLength int64, randomTokenLength int, logger *log.Logger, r *http.Request) metadata { metadata := metadata{ ContentType: strings.ToLower(contentType), ContentLength: contentLength, MaxDate: time.Time{}, Downloads: 0, MaxDownloads: -1, - DeletionToken: token(randomTokenLength) + token(randomTokenLength), + DeletionToken: token(randomTokenLength, logger) + token(randomTokenLength, logger), } if v := r.Header.Get("Max-Downloads"); v == "" { @@ -696,9 +697,9 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) { contentType := mime.TypeByExtension(filepath.Ext(vars["filename"])) - token := token(s.randomTokenLength) + token := token(s.randomTokenLength, s.logger) - metadata := metadataForRequest(contentType, contentLength, s.randomTokenLength, r) + metadata := metadataForRequest(contentType, contentLength, s.randomTokenLength, s.logger, r) buffer := &bytes.Buffer{} if err := json.NewEncoder(buffer).Encode(metadata); err != nil { diff --git a/server/token.go b/server/token.go index f3aa012e..c079caf4 100644 --- a/server/token.go +++ b/server/token.go @@ -25,7 +25,10 @@ THE SOFTWARE. package server import ( - "math/rand" + "crypto/rand" + "log" + "math/big" + mathrand "math/rand" ) const ( @@ -34,12 +37,23 @@ const ( ) // generate a token -func token(length int) string { - result := "" +func token(length int, logger *log.Logger) string { + result := make([]byte, length) + var err error for i := 0; i < length; i++ { - x := rand.Intn(len(SYMBOLS) - 1) - result = string(SYMBOLS[x]) + result + if err == nil { + var x *big.Int + x, err = rand.Int(rand.Reader, big.NewInt(int64(len(SYMBOLS)))) + if err != nil { + logger.Printf("Fallback to math/rand instead of crypto/rand due error %s", err.Error()) + x = big.NewInt(int64(mathrand.Intn(len(SYMBOLS) - 1))) + } + result[i] = SYMBOLS[x.Int64()] + } else { // fallback to math rand + x := int64(mathrand.Intn(len(SYMBOLS) - 1)) + result[i] = SYMBOLS[x] + } } - return result + return string(result) } diff --git a/server/token_test.go b/server/token_test.go index cec3d793..fa55704f 100644 --- a/server/token_test.go +++ b/server/token_test.go @@ -1,15 +1,21 @@ package server -import "testing" +import ( + "io" + "log" + "testing" +) + +var logger = log.New(io.Discard, "", log.LstdFlags) func BenchmarkTokenConcat(b *testing.B) { for i := 0; i < b.N; i++ { - _ = token(5) + token(5) + _ = token(5, logger) + token(5, logger) } } func BenchmarkTokenLonger(b *testing.B) { for i := 0; i < b.N; i++ { - _ = token(10) + _ = token(10, logger) } }