fix: settings and users saving

This commit is contained in:
face.wsl 2022-11-25 13:16:41 +08:00
parent f7910d47f9
commit 80f4f2f292
4 changed files with 47 additions and 41 deletions

View File

@ -4,7 +4,6 @@ import (
"github.com/asdine/storm/v3" "github.com/asdine/storm/v3"
"github.com/filebrowser/filebrowser/v2/settings" "github.com/filebrowser/filebrowser/v2/settings"
"github.com/filebrowser/filebrowser/v2/storage/sql"
) )
type settingsBackend struct { type settingsBackend struct {
@ -17,7 +16,6 @@ func (s settingsBackend) Get() (*settings.Settings, error) {
} }
func (s settingsBackend) Save(set *settings.Settings) error { func (s settingsBackend) Save(set *settings.Settings) error {
sql.LogBacktrace()
return save(s.db, "settings", set) return save(s.db, "settings", set)
} }
@ -27,6 +25,5 @@ func (s settingsBackend) GetServer() (*settings.Server, error) {
} }
func (s settingsBackend) SaveServer(server *settings.Server) error { func (s settingsBackend) SaveServer(server *settings.Server) error {
sql.LogBacktrace()
return save(s.db, "server", server) return save(s.db, "server", server)
} }

View File

@ -15,7 +15,6 @@ type authBackend struct {
} }
func (s authBackend) Get(t settings.AuthMethod) (auth.Auther, error) { func (s authBackend) Get(t settings.AuthMethod) (auth.Auther, error) {
logBacktrace()
var auther auth.Auther var auther auth.Auther
switch t { switch t {

View File

@ -2,6 +2,7 @@ package sql
import ( import (
"database/sql" "database/sql"
"encoding/base64"
"encoding/json" "encoding/json"
"github.com/filebrowser/filebrowser/v2/auth" "github.com/filebrowser/filebrowser/v2/auth"
@ -25,6 +26,14 @@ func InitSettingsTable(db *sql.DB) error {
return err return err
} }
func bytesToString(data []byte) string {
return base64.RawStdEncoding.EncodeToString(data)
}
func bytesFromString(s string) ([]byte, error) {
return base64.RawStdEncoding.DecodeString(s)
}
func userDefaultsFromString(s string) settings.UserDefaults { func userDefaultsFromString(s string) settings.UserDefaults {
if s == "" { if s == "" {
return settings.UserDefaults{} return settings.UserDefaults{}
@ -126,7 +135,6 @@ func boolToString(b bool) string {
} }
func (s settingsBackend) Get() (*settings.Settings, error) { func (s settingsBackend) Get() (*settings.Settings, error) {
logBacktrace()
sql := "select key, value from settings" sql := "select key, value from settings"
rows, err := s.db.Query(sql) rows, err := s.db.Query(sql)
if checkError(err, "Fail to Query settings.Settings") { if checkError(err, "Fail to Query settings.Settings") {
@ -139,11 +147,16 @@ func (s settingsBackend) Get() (*settings.Settings, error) {
err = rows.Scan(&key, &value) err = rows.Scan(&key, &value)
checkError(err, "Fail to query settings.Settings") checkError(err, "Fail to query settings.Settings")
if key == "Key" { if key == "Key" {
settings1.Key = []byte(value) val, err := bytesFromString(value)
if !checkError(err, "Fail to parse []byte from string") {
settings1.Key = val
}
} else if key == "Signup" { } else if key == "Signup" {
settings1.Signup = boolFromString(value) settings1.Signup = boolFromString(value)
} else if key == "CreateUserDir" { } else if key == "CreateUserDir" {
settings1.CreateUserDir = boolFromString(value) settings1.CreateUserDir = boolFromString(value)
} else if key == "UserHomeBasePath" {
settings1.UserHomeBasePath = value
} else if key == "Defaults" { } else if key == "Defaults" {
settings1.Defaults = userDefaultsFromString(value) settings1.Defaults = userDefaultsFromString(value)
} else if key == "AuthMethod" { } else if key == "AuthMethod" {
@ -165,13 +178,12 @@ func (s settingsBackend) Get() (*settings.Settings, error) {
} }
func (s settingsBackend) Save(ss *settings.Settings) error { func (s settingsBackend) Save(ss *settings.Settings) error {
logBacktrace()
fields := []string{"Key", "Signup", "CreateUserDir", "UserHomeBasePath", "Defaults", "AuthMethod", "Branding", "Commands", "Shell", "Rules"} fields := []string{"Key", "Signup", "CreateUserDir", "UserHomeBasePath", "Defaults", "AuthMethod", "Branding", "Commands", "Shell", "Rules"}
values := []string{ values := []string{
string(ss.Key), bytesToString(ss.Key),
boolToString(ss.Signup), boolToString(ss.Signup),
boolToString(ss.CreateUserDir), boolToString(ss.CreateUserDir),
string(ss.UserHomeBasePath), ss.UserHomeBasePath,
userDefaultsToString(ss.Defaults), userDefaultsToString(ss.Defaults),
string(ss.AuthMethod), string(ss.AuthMethod),
brandingToString(ss.Branding), brandingToString(ss.Branding),
@ -184,13 +196,18 @@ func (s settingsBackend) Save(ss *settings.Settings) error {
return err return err
} }
for i, field := range fields { for i, field := range fields {
stmt, err := s.db.Prepare("INSERT INTO settings (key, value) VALUES(?,?)") exists := ContainKey(s.db, field)
sql := "INSERT INTO settings (value, key) VALUES(?,?)"
if exists {
sql = "UPDATE settings set value = ? where key = ?"
}
stmt, err := s.db.Prepare(sql)
defer stmt.Close() defer stmt.Close()
if checkError(err, "Fail to prepare statement") { if checkError(err, "Fail to prepare statement") {
tx.Rollback() tx.Rollback()
break break
} }
_, err = stmt.Exec(field, values[i]) _, err = stmt.Exec(values[i], field)
if checkError(err, "Fail to insert field "+field+" of settings") { if checkError(err, "Fail to insert field "+field+" of settings") {
tx.Rollback() tx.Rollback()
break break
@ -276,7 +293,6 @@ func cloneSettings(s settings.Settings) settings.Settings {
} }
func (s settingsBackend) GetServer() (*settings.Server, error) { func (s settingsBackend) GetServer() (*settings.Server, error) {
logBacktrace()
sql := "select key, value from settings" sql := "select key, value from settings"
rows, err := s.db.Query(sql) rows, err := s.db.Query(sql)
if checkError(err, "Fail to Query for GetServer") { if checkError(err, "Fail to Query for GetServer") {
@ -323,7 +339,6 @@ func (s settingsBackend) GetServer() (*settings.Server, error) {
} }
func (s settingsBackend) SaveServer(ss *settings.Server) error { func (s settingsBackend) SaveServer(ss *settings.Server) error {
logBacktrace()
fields := []string{"Root", "BaseURL", "Socket", "TLSKey", "TLSCert", "Port", "Address", "Log", "EnableThumbnails", "ResizePreview", "EnableExec", "TypeDetectionByHeader", "AuthHook"} fields := []string{"Root", "BaseURL", "Socket", "TLSKey", "TLSCert", "Port", "Address", "Log", "EnableThumbnails", "ResizePreview", "EnableExec", "TypeDetectionByHeader", "AuthHook"}
values := []string{ values := []string{
ss.Root, ss.Root,
@ -408,3 +423,17 @@ func HadSetting(db *sql.DB) bool {
} }
return true return true
} }
func ContainKey(db *sql.DB, key string) bool {
sql := "select value from settings where key = '" + key + "';"
value := ""
err := db.QueryRow(sql).Scan(&value)
if checkError(err, "Fail to QueryRow for key "+key) {
return false
}
return true
}
func HadSettingOfKey(db *sql.DB, key string) bool {
return GetSetting(db, "Key") == key
}

View File

@ -4,7 +4,6 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
@ -117,24 +116,13 @@ func createAdminUser() users.User {
} }
func InitUserTable(db *sql.DB) error { func InitUserTable(db *sql.DB) error {
logBacktrace()
sql := "create table if not exists users (id integer primary key, username string, password string, scope string, locale string, lockpassword bool, viewmode string, perm string, commands string, sorting string, rules string, hidedotfiles bool, dateformat bool, singleclick bool);" sql := "create table if not exists users (id integer primary key, username string, password string, scope string, locale string, lockpassword bool, viewmode string, perm string, commands string, sorting string, rules string, hidedotfiles bool, dateformat bool, singleclick bool);"
_, err := db.Exec(sql) _, err := db.Exec(sql)
if checkError(err, "Fail to create users table") { checkError(err, "Fail to create users table")
return err
}
user, err := usersBackend{db}.Get("admin")
checkError(err, "Fail to query admin user")
if user == nil {
log.Println("No admin exists")
err := usersBackend{db}.Save(&adminUser)
checkError(err, "Fail to init admin user")
}
return err return err
} }
func (s usersBackend) Get(i interface{}) (*users.User, error) { func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
logBacktrace()
columns := []string{"id", "username", "password", "scope", "locale", "lockpassword", "viewmode", "perm", "commands", "sorting", "rules", "hidedotfiles", "dateformat", "singleclick"} columns := []string{"id", "username", "password", "scope", "locale", "lockpassword", "viewmode", "perm", "commands", "sorting", "rules", "hidedotfiles", "dateformat", "singleclick"}
columnsStr := strings.Join(columns, ",") columnsStr := strings.Join(columns, ",")
var conditionStr string var conditionStr string
@ -186,7 +174,6 @@ func (s usersBackend) Get(i interface{}) (*users.User, error) {
} }
func (s usersBackend) Gets() ([]*users.User, error) { func (s usersBackend) Gets() ([]*users.User, error) {
logBacktrace()
sql := "select id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules from users" sql := "select id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules from users"
rows, err := s.db.Query(sql) rows, err := s.db.Query(sql)
if checkError(err, "Fail to Query []*users.User") { if checkError(err, "Fail to Query []*users.User") {
@ -225,13 +212,7 @@ func (s usersBackend) Gets() ([]*users.User, error) {
return users2, nil return users2, nil
} }
func (s usersBackend) GetBy(id interface{}) (*users.User, error) {
logBacktrace()
return s.Get(id)
}
func (s usersBackend) updateUser(id uint, user *users.User) error { func (s usersBackend) updateUser(id uint, user *users.User) error {
logBacktrace()
lockpassword := 0 lockpassword := 0
if user.LockPassword { if user.LockPassword {
lockpassword = 1 lockpassword = 1
@ -255,7 +236,6 @@ func (s usersBackend) updateUser(id uint, user *users.User) error {
} }
func (s usersBackend) insertUser(user *users.User) error { func (s usersBackend) insertUser(user *users.User) error {
logBacktrace()
password, err := users.HashPwd(user.Password) password, err := users.HashPwd(user.Password)
if checkError(err, "Fail to hash password") { if checkError(err, "Fail to hash password") {
return err return err
@ -300,13 +280,17 @@ func (s usersBackend) insertUser(user *users.User) error {
boolToString(user.DateFormat), boolToString(user.DateFormat),
boolToString(user.SingleClick), boolToString(user.SingleClick),
) )
_, err = s.db.Exec(sql) res, err := s.db.Exec(sql)
checkError(err, "Fail to insert user") if !checkError(err, "Fail to insert user") {
id, err2 := res.LastInsertId()
if !checkError(err2, "Fail to fetch last insert id") {
user.ID = uint(id)
}
}
return err return err
} }
func (s usersBackend) Save(user *users.User) error { func (s usersBackend) Save(user *users.User) error {
logBacktrace()
userOriginal, err := s.GetBy(user.Username) userOriginal, err := s.GetBy(user.Username)
checkError(err, "") checkError(err, "")
if userOriginal != nil { if userOriginal != nil {
@ -316,7 +300,6 @@ func (s usersBackend) Save(user *users.User) error {
} }
func (s usersBackend) DeleteByID(id uint) error { func (s usersBackend) DeleteByID(id uint) error {
logBacktrace()
sql := "delete from users where id=" + strconv.Itoa(int(id)) sql := "delete from users where id=" + strconv.Itoa(int(id))
_, err := s.db.Exec(sql) _, err := s.db.Exec(sql)
checkError(err, "Fail to delete User by id") checkError(err, "Fail to delete User by id")
@ -324,7 +307,6 @@ func (s usersBackend) DeleteByID(id uint) error {
} }
func (s usersBackend) DeleteByUsername(username string) error { func (s usersBackend) DeleteByUsername(username string) error {
logBacktrace()
sql := "delete from users where username='" + username + "'" sql := "delete from users where username='" + username + "'"
_, err := s.db.Exec(sql) _, err := s.db.Exec(sql)
checkError(err, "Fail to delete user by username") checkError(err, "Fail to delete user by username")
@ -332,7 +314,6 @@ func (s usersBackend) DeleteByUsername(username string) error {
} }
func (s usersBackend) Update(u *users.User, fields ...string) error { func (s usersBackend) Update(u *users.User, fields ...string) error {
logBacktrace()
var setItems = []string{} var setItems = []string{}
for _, field := range fields { for _, field := range fields {
userField := reflect.ValueOf(u).Elem().FieldByName(field) userField := reflect.ValueOf(u).Elem().FieldByName(field)