diff --git a/http/data.go b/http/data.go index e480de43..2d17543a 100644 --- a/http/data.go +++ b/http/data.go @@ -4,7 +4,6 @@ import ( "log" "net/http" "strconv" - "time" "github.com/tomasen/realip" @@ -50,7 +49,6 @@ func (d *data) Check(path string) bool { func handle(fn handleFunc, prefix string, store *storage.Storage, server *settings.Server) http.Handler { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - begin := time.Now() for k, v := range globalHeaders { w.Header().Set(k, v) } @@ -61,22 +59,12 @@ func handle(fn handleFunc, prefix string, store *storage.Storage, server *settin return } - d := data{ + status, err := fn(w, r, &data{ Runner: &runner.Runner{Enabled: server.EnableExec, Settings: settings}, store: store, settings: settings, server: server, - } - - status, err := fn(w, r, &d) - if server.EnableRequestLog { - LogRequest(w, r, server.RequestLogFormat, RequestLog{ - user: d.user, - status: status, - elapsed: time.Now().Sub(begin).Seconds(), - }) - } - + }) if status >= 400 || err != nil { clientIP := realip.FromRequest(r) log.Printf("%s: %v %s %v", r.URL.Path, status, clientIP, err) diff --git a/http/http.go b/http/http.go index 8d0c4af7..0187febf 100644 --- a/http/http.go +++ b/http/http.go @@ -39,11 +39,11 @@ func NewHandler( r = r.SkipClean(true) monkey := func(fn handleFunc, prefix string) http.Handler { - return handle(fn, prefix, store, server) + return handle(RequestLogHandleFunc(fn, server), prefix, store, server) } - r.HandleFunc("/health", healthHandler) - r.PathPrefix("/static").Handler(static) + r.HandleFunc("/health", RequestLogHandlerFunc(healthHandler, server)) + r.PathPrefix("/static").Handler(RequestLogHandler(static, server)) r.NotFoundHandler = index api := r.PathPrefix("/api").Subrouter() diff --git a/http/request_log.go b/http/request_log.go index 3d274c7b..e7b82b2d 100644 --- a/http/request_log.go +++ b/http/request_log.go @@ -8,15 +8,17 @@ import ( "strings" "time" + "github.com/filebrowser/filebrowser/v2/settings" "github.com/filebrowser/filebrowser/v2/users" + "github.com/tomasen/realip" ) type RequestLog struct { user *users.User ip string time time.Time - request_size int64 - response_size int64 + request_size uint64 + response_size uint64 path string method string status int @@ -50,19 +52,19 @@ func (r *RequestLog) time_string() string { return r.time.Format(time.RFC3339) } -func LogRequest(w http.ResponseWriter, r *http.Request, format string, log_ RequestLog) { +func logRequest(w http.ResponseWriter, r *http.Request, format string, log_ RequestLog) { if log_.status == 0 { log_.status = 200 } - log_.ip = getRealIp(r) + log_.ip = realip.FromRequest(r) log_.time = time.Now() - log_.request_size = r.ContentLength + log_.request_size = getRequestSize(r) if log_.response_size == 0 { - log_.response_size = parseSize(w.Header().Get("Content-Length")) + log_.response_size = str2uint64(w.Header().Get("Content-Length")) } log_.origin = r.Header.Get("Origin") log_.referer = r.Header.Get("Referer") - log_.path = r.URL.Path + log_.path = r.RequestURI log_.method = r.Method log.Println(formatLog(format, log_)) } @@ -120,12 +122,12 @@ func parseFirstItem(s string) string { return items[0] } -func parseSize(d string) int64 { +func str2uint64(d string) uint64 { val, err := strconv.ParseInt(d, 10, 64) if err != nil { return 0 } - return val + return uint64(val) } func int2string(val any) string { @@ -135,3 +137,61 @@ func int2string(val any) string { func float2string(val float64) string { return fmt.Sprintf("%f", val) } + +func getRequestSize(r *http.Request) uint64 { + return uint64(r.ContentLength) +} + +type myHandler struct { + f http.HandlerFunc +} + +func (h *myHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.f(w, r) +} + +func handlerfunc2handler(f http.HandlerFunc) http.Handler { + return &myHandler{f: f} +} +func handler2handlerfunc(h http.Handler) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + }) +} + +func RequestLogHandlerFunc(handler http.HandlerFunc, server *settings.Server) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + begin := time.Now() + writer := MakeResponseWriterWrapper(w) + handler(writer, r) + if server.EnableRequestLog { + logRequest(w, r, server.RequestLogFormat, RequestLog{ + user: nil, + status: writer.GetStatus(), + elapsed: time.Now().Sub(begin).Seconds(), + response_size: writer.GetSize(), + }) + } + }) +} + +func RequestLogHandler(handler http.Handler, server *settings.Server) http.Handler { + return handlerfunc2handler(RequestLogHandlerFunc(handler2handlerfunc(handler), server)) +} + +func RequestLogHandleFunc(handle handleFunc, server *settings.Server) handleFunc { + return func(w http.ResponseWriter, r *http.Request, d *data) (int, error) { + begin := time.Now() + writer := MakeResponseWriterWrapper(w) + status, err := handle(writer, r, d) + if server.EnableRequestLog { + logRequest(w, r, server.RequestLogFormat, RequestLog{ + user: d.user, + status: writer.GetStatus(), + elapsed: time.Now().Sub(begin).Seconds(), + response_size: writer.GetSize(), + }) + } + return status, err + } +} diff --git a/http/response_writer_wrapper.go b/http/response_writer_wrapper.go new file mode 100644 index 00000000..aa9a09e6 --- /dev/null +++ b/http/response_writer_wrapper.go @@ -0,0 +1,66 @@ +package http + +import ( + "net/http" +) + +type _size struct { + value uint64 +} + +func (s *_size) get() uint64 { + return s.value +} +func (s *_size) set(v uint64) { + s.value = v +} +func (s *_size) add(v uint64) { + s.value += v +} + +type _status struct { + value int +} + +func (s *_status) get() int { + return s.value +} +func (s *_status) set(v int) { + s.value = v +} + +type ResponseWriterWrapper struct { + writer http.ResponseWriter + size *_size + status *_status +} + +func MakeResponseWriterWrapper(w http.ResponseWriter) *ResponseWriterWrapper { + return &ResponseWriterWrapper{ + writer: w, + size: &_size{value: 0}, + status: &_status{value: 0}, + } +} + +func (r *ResponseWriterWrapper) Write(data []byte) (int, error) { + r.size.add(uint64(len(data))) + return r.writer.Write(data) +} + +func (r *ResponseWriterWrapper) Header() http.Header { + return r.writer.Header() +} + +func (r *ResponseWriterWrapper) WriteHeader(statusCode int) { + r.status.set(statusCode) + r.writer.WriteHeader(statusCode) +} + +func (r *ResponseWriterWrapper) GetSize() uint64 { + return r.size.get() +} + +func (r *ResponseWriterWrapper) GetStatus() int { + return r.status.get() +}