From 050e125dcc6022427ee802c0b193043045aaf93d Mon Sep 17 00:00:00 2001 From: "face.wsl" Date: Sat, 26 Nov 2022 00:33:13 +0800 Subject: [PATCH] feat: env-based table name config --- storage/sql/config.go | 21 +++++++ storage/sql/server.go | 116 ++++++++++++++++++++++++++++++++++++ storage/sql/settings.go | 126 ++++------------------------------------ storage/sql/share.go | 16 ++--- storage/sql/sql.go | 12 +--- storage/sql/users.go | 35 +++++------ 6 files changed, 172 insertions(+), 154 deletions(-) create mode 100644 storage/sql/config.go create mode 100644 storage/sql/server.go diff --git a/storage/sql/config.go b/storage/sql/config.go new file mode 100644 index 00000000..3106884c --- /dev/null +++ b/storage/sql/config.go @@ -0,0 +1,21 @@ +package sql + +import "os" + +var SettingsTable = "fb_settings" +var UsersTable = "fb_users" +var SharesTable = "fb_shares" + +func getEnv(key string, defaultValue string) string { + val := os.Getenv(key) + if len(val) == 0 { + return defaultValue + } + return val +} + +func init() { + SettingsTable = getEnv("FILEBROWSER_SETTINGS_TABLE", SettingsTable) + UsersTable = getEnv("FILEBROWSER_USERS_TABLE", UsersTable) + SharesTable = getEnv("FILEBROWSER_SHARES_TABLE", SharesTable) +} diff --git a/storage/sql/server.go b/storage/sql/server.go new file mode 100644 index 00000000..f75a1e89 --- /dev/null +++ b/storage/sql/server.go @@ -0,0 +1,116 @@ +package sql + +import ( + "encoding/json" + "fmt" + + "github.com/filebrowser/filebrowser/v2/settings" +) + +var defaultServer = settings.Server{ + Port: "8080", + Log: "stdout", + EnableThumbnails: false, + ResizePreview: false, + EnableExec: false, + TypeDetectionByHeader: false, +} + +func cloneServer(server settings.Server) settings.Server { + data, err := json.Marshal(server) + s := settings.Server{} + if checkError(err, "Fail to clone settings.Server") { + return s + } + err = json.Unmarshal(data, &s) + checkError(err, "Fail to decode for settings.Server") + return s +} + +func (s settingsBackend) GetServer() (*settings.Server, error) { + sql := fmt.Sprintf("select key, value from %s", SettingsTable) + rows, err := s.db.Query(sql) + if checkError(err, "Fail to Query for GetServer") { + return nil, err + } + server := cloneServer(defaultServer) + key := "" + value := "" + + for rows.Next() { + err = rows.Scan(&key, &value) + if checkError(err, "Fail to query settings.Settings") { + continue + } + if key == "Root" { + server.Root = value + } else if key == "BaseURL" { + server.BaseURL = value + } else if key == "Socket" { + server.Socket = value + } else if key == "TLSKey" { + server.TLSKey = value + } else if key == "TLSCert" { + server.TLSCert = value + } else if key == "Port" { + server.Port = value + } else if key == "Address" { + server.Address = value + } else if key == "Log" { + server.Log = value + } else if key == "EnableThumbnails" { + server.EnableThumbnails = boolFromString(value) + } else if key == "ResizePreview" { + server.ResizePreview = boolFromString(value) + } else if key == "EnableExec" { + server.EnableExec = boolFromString(value) + } else if key == "TypeDetectionByHeader" { + server.TypeDetectionByHeader = boolFromString(value) + } else if key == "AuthHook" { + server.AuthHook = value + } + } + return &server, nil +} + +func (s settingsBackend) SaveServer(ss *settings.Server) error { + fields := []string{"Root", "BaseURL", "Socket", "TLSKey", "TLSCert", "Port", "Address", "Log", "EnableThumbnails", "ResizePreview", "EnableExec", "TypeDetectionByHeader", "AuthHook"} + values := []string{ + ss.Root, + ss.BaseURL, + ss.Socket, + ss.TLSKey, + ss.TLSCert, + ss.Port, + ss.Address, + ss.Log, + boolToString(ss.EnableThumbnails), + boolToString(ss.ResizePreview), + boolToString(ss.EnableExec), + boolToString(ss.TypeDetectionByHeader), + ss.AuthHook} + tx, err := s.db.Begin() + if checkError(err, "Fail to begin db transaction") { + return err + } + sql := fmt.Sprintf("INSERT INTO \"%s\" (key, value) VALUES(?,?)", SettingsTable) + for i, field := range fields { + stmt, err := s.db.Prepare(sql) + defer stmt.Close() + if checkError(err, "Fail to prepare statement") { + tx.Rollback() + break + } + _, err = stmt.Exec(field, values[i]) + if checkError(err, "Fail to insert field "+field+" of settings.Server") { + tx.Rollback() + break + } + } + err = tx.Commit() + if checkError(err, "Fail to commit") { + tx.Rollback() + return err + } + return err +} diff --git a/storage/sql/settings.go b/storage/sql/settings.go index d24e6e87..54737dee 100644 --- a/storage/sql/settings.go +++ b/storage/sql/settings.go @@ -4,6 +4,7 @@ import ( "database/sql" "encoding/base64" "encoding/json" + "fmt" "github.com/filebrowser/filebrowser/v2/auth" "github.com/filebrowser/filebrowser/v2/files" @@ -20,7 +21,7 @@ type settingsBackend struct { } func InitSettingsTable(db *sql.DB) error { - sql := "create table if not exists settings(key string primary key, value string)" + sql := fmt.Sprintf("create table if not exists \"%s\"(key string primary key, value string)", SettingsTable) _, err := db.Exec(sql) checkError(err, "Fail to create table settings") return err @@ -135,7 +136,7 @@ func boolToString(b bool) string { } func (s settingsBackend) Get() (*settings.Settings, error) { - sql := "select key, value from settings" + sql := fmt.Sprintf("select key, value from \"%s\"", SettingsTable) rows, err := s.db.Query(sql) if checkError(err, "Fail to Query settings.Settings") { return nil, err @@ -197,9 +198,9 @@ func (s settingsBackend) Save(ss *settings.Settings) error { } for i, field := range fields { exists := ContainKey(s.db, field) - sql := "INSERT INTO settings (value, key) VALUES(?,?)" + sql := fmt.Sprintf("INSERT INTO \"%s\" (value, key) VALUES(?,?)", SettingsTable) if exists { - sql = "UPDATE settings set value = ? where key = ?" + sql = fmt.Sprintf("UPDATE \"%s\" set value = ? where key = ?", SettingsTable) } stmt, err := s.db.Prepare(sql) defer stmt.Close() @@ -221,15 +222,6 @@ func (s settingsBackend) Save(ss *settings.Settings) error { return err } -var defaultServer = settings.Server{ - Port: "8080", - Log: "stdout", - EnableThumbnails: false, - ResizePreview: false, - EnableExec: false, - TypeDetectionByHeader: false, -} - var defaultSettings = settings.Settings{ Key: []byte(""), Signup: false, @@ -271,17 +263,6 @@ var defaultSettings = settings.Settings{ Rules: make([]rules.Rule, 0), } -func cloneServer(server settings.Server) settings.Server { - data, err := json.Marshal(server) - s := settings.Server{} - if checkError(err, "Fail to clone settings.Server") { - return s - } - err = json.Unmarshal(data, &s) - checkError(err, "Fail to decode for settings.Server") - return s -} - func cloneSettings(s settings.Settings) settings.Settings { data, err := json.Marshal(s) s1 := settings.Settings{} @@ -292,95 +273,8 @@ func cloneSettings(s settings.Settings) settings.Settings { return s1 } -func (s settingsBackend) GetServer() (*settings.Server, error) { - sql := "select key, value from settings" - rows, err := s.db.Query(sql) - if checkError(err, "Fail to Query for GetServer") { - return nil, err - } - server := cloneServer(defaultServer) - key := "" - value := "" - - for rows.Next() { - err = rows.Scan(&key, &value) - if checkError(err, "Fail to query settings.Settings") { - continue - } - if key == "Root" { - server.Root = value - } else if key == "BaseURL" { - server.BaseURL = value - } else if key == "Socket" { - server.Socket = value - } else if key == "TLSKey" { - server.TLSKey = value - } else if key == "TLSCert" { - server.TLSCert = value - } else if key == "Port" { - server.Port = value - } else if key == "Address" { - server.Address = value - } else if key == "Log" { - server.Log = value - } else if key == "EnableThumbnails" { - server.EnableThumbnails = boolFromString(value) - } else if key == "ResizePreview" { - server.ResizePreview = boolFromString(value) - } else if key == "EnableExec" { - server.EnableExec = boolFromString(value) - } else if key == "TypeDetectionByHeader" { - server.TypeDetectionByHeader = boolFromString(value) - } else if key == "AuthHook" { - server.AuthHook = value - } - } - return &server, nil -} - -func (s settingsBackend) SaveServer(ss *settings.Server) error { - fields := []string{"Root", "BaseURL", "Socket", "TLSKey", "TLSCert", "Port", "Address", "Log", "EnableThumbnails", "ResizePreview", "EnableExec", "TypeDetectionByHeader", "AuthHook"} - values := []string{ - ss.Root, - ss.BaseURL, - ss.Socket, - ss.TLSKey, - ss.TLSCert, - ss.Port, - ss.Address, - ss.Log, - boolToString(ss.EnableThumbnails), - boolToString(ss.ResizePreview), - boolToString(ss.EnableExec), - boolToString(ss.TypeDetectionByHeader), - ss.AuthHook} - tx, err := s.db.Begin() - if checkError(err, "Fail to begin db transaction") { - return err - } - for i, field := range fields { - stmt, err := s.db.Prepare("INSERT INTO settings (key, value) VALUES(?,?)") - defer stmt.Close() - if checkError(err, "Fail to prepare statement") { - tx.Rollback() - break - } - _, err = stmt.Exec(field, values[i]) - if checkError(err, "Fail to insert field "+field+" of settings") { - tx.Rollback() - break - } - } - err = tx.Commit() - if checkError(err, "Fail to commit") { - tx.Rollback() - return err - } - return err -} - func SetSetting(db *sql.DB, key string, value string) error { - sql := "select count(key) from settings where key = '" + key + "'" + sql := fmt.Sprintf("select count(key) from \"%s\" where key = '%s'", SettingsTable, key) count := 0 err := db.QueryRow(sql).Scan(&count) if checkError(err, "Fail to QueryRow for key="+key) { @@ -393,7 +287,7 @@ func SetSetting(db *sql.DB, key string, value string) error { } func GetSetting(db *sql.DB, key string) string { - sql := "select value from settings where key = '" + key + "';" + sql := fmt.Sprintf("select value from \"%s\" where key = '%s'", SettingsTable, key) value := "" err := db.QueryRow(sql).Scan(&value) if checkError(err, "Fail to QueryRow for key "+key) { @@ -403,14 +297,14 @@ func GetSetting(db *sql.DB, key string) string { } func addSetting(db *sql.DB, key string, value string) error { - sql := "insert into settings(key, value) values('" + key + "', '" + value + "')" + sql := fmt.Sprintf("insert into \"%s\" (key, value) values('%s', '%s')", SettingsTable, key, value) _, err := db.Exec(sql) checkError(err, "Fail to addSetting") return err } func updateSetting(db *sql.DB, key string, value string) error { - sql := "update settings set value = '" + value + "' where key = '" + key + "'" + sql := fmt.Sprintf("update \"%s\" set value = '%s' where key = '%s'", SettingsTable, value, key) _, err := db.Exec(sql) checkError(err, "Fail to updateSetting") return err @@ -425,7 +319,7 @@ func HadSetting(db *sql.DB) bool { } func ContainKey(db *sql.DB, key string) bool { - sql := "select value from settings where key = '" + key + "';" + sql := fmt.Sprintf("select value from \"%s\" where key = '%s'", SettingsTable, key) value := "" err := db.QueryRow(sql).Scan(&value) if checkError(err, "Fail to QueryRow for key "+key) { diff --git a/storage/sql/share.go b/storage/sql/share.go index 21e3182e..03630d67 100644 --- a/storage/sql/share.go +++ b/storage/sql/share.go @@ -15,10 +15,10 @@ type linkRecord interface { Scan(dest ...interface{}) error } -func InitShareTable(db *sql.DB) error { - sql := "create table if not exists share_links (hash string, path string, userid integer, expire integer, passwordhash string, token string)" +func InitSharesTable(db *sql.DB) error { + sql := fmt.Sprintf("create table if not exists \"%s\" (hash string, path string, userid integer, expire integer, passwordhash string, token string)", SharesTable) _, err := db.Exec(sql) - checkError(err, "Fail to InitShareTable") + checkError(err, "Fail to InitSharesTable") return err } @@ -44,7 +44,7 @@ func parseLink(row linkRecord) (*share.Link, error) { } func queryLinks(db *sql.DB, condition string) ([]*share.Link, error) { - sql := "select hash, path, userid, expire, passwordhash, token from share_links" + sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\"", SharesTable) if len(condition) > 0 { sql = sql + " where " + condition } @@ -73,12 +73,12 @@ func (s shareBackend) FindByUserID(id uint) ([]*share.Link, error) { } func (s shareBackend) GetByHash(hash string) (*share.Link, error) { - sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from share_links where hash='%s'", hash) + sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\" where hash='%s'", SharesTable, hash) return parseLink(s.db.QueryRow(sql)) } func (s shareBackend) GetPermanent(path string, id uint) (*share.Link, error) { - sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from share_links where path='%s' and userid=%d", path, id) + sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\" where path='%s' and userid=%d", SharesTable, path, id) return parseLink(s.db.QueryRow(sql)) } @@ -87,13 +87,13 @@ func (s shareBackend) Gets(path string, id uint) ([]*share.Link, error) { return queryLinks(s.db, condition) } func (s shareBackend) Save(l *share.Link) error { - sql := fmt.Sprintf("insert into share_links (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token) + sql := fmt.Sprintf("insert into \"%s\" (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", SharesTable, l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token) _, err := s.db.Exec(sql) checkError(err, "Fail to Save share") return err } func (s shareBackend) Delete(hash string) error { - sql := fmt.Sprintf("DELETE FROM share_links WHERE hash='%s'", hash) + sql := fmt.Sprintf("DELETE FROM \"%s\" WHERE hash='%s'", SharesTable, hash) _, err := s.db.Exec(sql) checkError(err, "Fail to Delete share") return err diff --git a/storage/sql/sql.go b/storage/sql/sql.go index f0698061..16ba7cc6 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -46,7 +46,7 @@ func connectDB(dbType string, path string) (*sql.DB, error) { func NewStorage(db *sql.DB) (*storage.Storage, error) { InitUserTable(db) - InitShareTable(db) + InitSharesTable(db) InitSettingsTable(db) userStore := users.NewStorage(usersBackend{db: db}) @@ -59,16 +59,6 @@ func NewStorage(db *sql.DB) (*storage.Storage, error) { return nil, err } - // TODO: default - /* - if GetSetting(db, "auther") == "" { - err := SetSetting(db, "auther", "json") - if checkError(err, "Fail to set auther") { - return nil, err - } - } - */ - storage := &storage.Storage{ Auth: authStore, Users: userStore, diff --git a/storage/sql/users.go b/storage/sql/users.go index d24a2e83..2d7269aa 100644 --- a/storage/sql/users.go +++ b/storage/sql/users.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "reflect" - "strconv" "strings" "github.com/filebrowser/filebrowser/v2/errors" @@ -116,7 +115,7 @@ func createAdminUser() users.User { } func InitUserTable(db *sql.DB) error { - sql := "create table if not exists users (id integer primary key, username string, password string, scope string, locale string, lockpassword bool, viewmode string, perm string, commands string, sorting string, rules string, hidedotfiles bool, dateformat bool, singleclick bool);" + sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS \"%s\" (id integer primary key, username string, password string, scope string, locale string, lockpassword bool, viewmode string, perm string, commands string, sorting string, rules string, hidedotfiles bool, dateformat bool, singleclick bool);", UsersTable) _, err := db.Exec(sql) checkError(err, "Fail to create users table") return err @@ -151,7 +150,7 @@ func (s usersBackend) GetBy(i interface{}) (*users.User, error) { dateformat := false singleclick := false user := users.User{} - sql := fmt.Sprintf("select %s from users where %s", columnsStr, conditionStr) + sql := fmt.Sprintf("SELECT %s FROM \"%s\" WHERE %s", columnsStr, UsersTable, conditionStr) err := s.db.QueryRow(sql).Scan(&userID, &username, &password, &scope, &locale, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules, &hidedotfiles, &dateformat, &singleclick) if checkError(err, "") { return nil, err @@ -174,7 +173,7 @@ func (s usersBackend) GetBy(i interface{}) (*users.User, error) { } func (s usersBackend) Gets() ([]*users.User, error) { - sql := "select id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules from users" + sql := fmt.Sprintf("SELECT id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules FROM \"%s\"", UsersTable) rows, err := s.db.Query(sql) if checkError(err, "Fail to Query []*users.User") { return nil, err @@ -218,7 +217,8 @@ func (s usersBackend) updateUser(id uint, user *users.User) error { lockpassword = 1 } sql := fmt.Sprintf( - "update users set username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' where id=%d", + "UPDATE \"%s\" SET username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' WHERE id=%d", + UsersTable, user.Username, user.Password, user.Scope, @@ -236,10 +236,6 @@ func (s usersBackend) updateUser(id uint, user *users.User) error { } func (s usersBackend) insertUser(user *users.User) error { - password, err := users.HashPwd(user.Password) - if checkError(err, "Fail to hash password") { - return err - } columnSpec := [][]string{ {"username", "'%s'"}, {"password", "'%s'"}, @@ -263,11 +259,11 @@ func (s usersBackend) insertUser(user *users.User) error { } columnStr := strings.Join(columns, ",") specStr := strings.Join(specs, ",") - sqlFormat := fmt.Sprintf("insert into users (%s) values (%s)", columnStr, specStr) + sqlFormat := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)", UsersTable, columnStr, specStr) sql := fmt.Sprintf( sqlFormat, user.Username, - password, + user.Password, user.Scope, user.Locale, boolToString(user.LockPassword), @@ -300,20 +296,23 @@ func (s usersBackend) Save(user *users.User) error { } func (s usersBackend) DeleteByID(id uint) error { - sql := "delete from users where id=" + strconv.Itoa(int(id)) + sql := fmt.Sprintf("delete from \"%s\" where id=%d", UsersTable, id) _, err := s.db.Exec(sql) checkError(err, "Fail to delete User by id") return err } func (s usersBackend) DeleteByUsername(username string) error { - sql := "delete from users where username='" + username + "'" + sql := fmt.Sprintf("delete from \"%s\" where username='%s'", UsersTable, username) _, err := s.db.Exec(sql) checkError(err, "Fail to delete user by username") return err } func (s usersBackend) Update(u *users.User, fields ...string) error { + if len(fields) == 0 { + return s.Save(u) + } var setItems = []string{} for _, field := range fields { userField := reflect.ValueOf(u).Elem().FieldByName(field) @@ -323,18 +322,16 @@ func (s usersBackend) Update(u *users.User, fields ...string) error { field = strings.ToLower(field) val := userField.Interface() typeStr := reflect.TypeOf(val).Kind().String() - fmt.Println(typeStr) if typeStr == "string" { - setItems = append(setItems, fmt.Sprintf("%s='%s'", field, val)) + setItems = append(setItems, fmt.Sprintf("\"%s\"='%s'", field, val)) } else if typeStr == "bool" { - setItems = append(setItems, fmt.Sprintf("%s=%s", field, boolToString(val.(bool)))) + setItems = append(setItems, fmt.Sprintf("\"%s\"=%s", field, boolToString(val.(bool)))) } else { // TODO - setItems = append(setItems, fmt.Sprintf("%s=%s", field, val)) + setItems = append(setItems, fmt.Sprintf("\"%s\"=%s", field, val)) } } - sql := fmt.Sprintf("update users set %s where id=%d", strings.Join(setItems, ","), u.ID) - fmt.Println(sql) + sql := fmt.Sprintf("UPDATE \"%s\" SET %s WHERE id=%d", UsersTable, strings.Join(setItems, ","), u.ID) _, err := s.db.Exec(sql) checkError(err, "Fail to update user") return err