feat: env-based table name config
This commit is contained in:
parent
80f4f2f292
commit
050e125dcc
21
storage/sql/config.go
Normal file
21
storage/sql/config.go
Normal file
@ -0,0 +1,21 @@
|
||||
package sql
|
||||
|
||||
import "os"
|
||||
|
||||
var SettingsTable = "fb_settings"
|
||||
var UsersTable = "fb_users"
|
||||
var SharesTable = "fb_shares"
|
||||
|
||||
func getEnv(key string, defaultValue string) string {
|
||||
val := os.Getenv(key)
|
||||
if len(val) == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func init() {
|
||||
SettingsTable = getEnv("FILEBROWSER_SETTINGS_TABLE", SettingsTable)
|
||||
UsersTable = getEnv("FILEBROWSER_USERS_TABLE", UsersTable)
|
||||
SharesTable = getEnv("FILEBROWSER_SHARES_TABLE", SharesTable)
|
||||
}
|
||||
116
storage/sql/server.go
Normal file
116
storage/sql/server.go
Normal file
@ -0,0 +1,116 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/filebrowser/filebrowser/v2/settings"
|
||||
)
|
||||
|
||||
var defaultServer = settings.Server{
|
||||
Port: "8080",
|
||||
Log: "stdout",
|
||||
EnableThumbnails: false,
|
||||
ResizePreview: false,
|
||||
EnableExec: false,
|
||||
TypeDetectionByHeader: false,
|
||||
}
|
||||
|
||||
func cloneServer(server settings.Server) settings.Server {
|
||||
data, err := json.Marshal(server)
|
||||
s := settings.Server{}
|
||||
if checkError(err, "Fail to clone settings.Server") {
|
||||
return s
|
||||
}
|
||||
err = json.Unmarshal(data, &s)
|
||||
checkError(err, "Fail to decode for settings.Server")
|
||||
return s
|
||||
}
|
||||
|
||||
func (s settingsBackend) GetServer() (*settings.Server, error) {
|
||||
sql := fmt.Sprintf("select key, value from %s", SettingsTable)
|
||||
rows, err := s.db.Query(sql)
|
||||
if checkError(err, "Fail to Query for GetServer") {
|
||||
return nil, err
|
||||
}
|
||||
server := cloneServer(defaultServer)
|
||||
key := ""
|
||||
value := ""
|
||||
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&key, &value)
|
||||
if checkError(err, "Fail to query settings.Settings") {
|
||||
continue
|
||||
}
|
||||
if key == "Root" {
|
||||
server.Root = value
|
||||
} else if key == "BaseURL" {
|
||||
server.BaseURL = value
|
||||
} else if key == "Socket" {
|
||||
server.Socket = value
|
||||
} else if key == "TLSKey" {
|
||||
server.TLSKey = value
|
||||
} else if key == "TLSCert" {
|
||||
server.TLSCert = value
|
||||
} else if key == "Port" {
|
||||
server.Port = value
|
||||
} else if key == "Address" {
|
||||
server.Address = value
|
||||
} else if key == "Log" {
|
||||
server.Log = value
|
||||
} else if key == "EnableThumbnails" {
|
||||
server.EnableThumbnails = boolFromString(value)
|
||||
} else if key == "ResizePreview" {
|
||||
server.ResizePreview = boolFromString(value)
|
||||
} else if key == "EnableExec" {
|
||||
server.EnableExec = boolFromString(value)
|
||||
} else if key == "TypeDetectionByHeader" {
|
||||
server.TypeDetectionByHeader = boolFromString(value)
|
||||
} else if key == "AuthHook" {
|
||||
server.AuthHook = value
|
||||
}
|
||||
}
|
||||
return &server, nil
|
||||
}
|
||||
|
||||
func (s settingsBackend) SaveServer(ss *settings.Server) error {
|
||||
fields := []string{"Root", "BaseURL", "Socket", "TLSKey", "TLSCert", "Port", "Address", "Log", "EnableThumbnails", "ResizePreview", "EnableExec", "TypeDetectionByHeader", "AuthHook"}
|
||||
values := []string{
|
||||
ss.Root,
|
||||
ss.BaseURL,
|
||||
ss.Socket,
|
||||
ss.TLSKey,
|
||||
ss.TLSCert,
|
||||
ss.Port,
|
||||
ss.Address,
|
||||
ss.Log,
|
||||
boolToString(ss.EnableThumbnails),
|
||||
boolToString(ss.ResizePreview),
|
||||
boolToString(ss.EnableExec),
|
||||
boolToString(ss.TypeDetectionByHeader),
|
||||
ss.AuthHook}
|
||||
tx, err := s.db.Begin()
|
||||
if checkError(err, "Fail to begin db transaction") {
|
||||
return err
|
||||
}
|
||||
sql := fmt.Sprintf("INSERT INTO \"%s\" (key, value) VALUES(?,?)", SettingsTable)
|
||||
for i, field := range fields {
|
||||
stmt, err := s.db.Prepare(sql)
|
||||
defer stmt.Close()
|
||||
if checkError(err, "Fail to prepare statement") {
|
||||
tx.Rollback()
|
||||
break
|
||||
}
|
||||
_, err = stmt.Exec(field, values[i])
|
||||
if checkError(err, "Fail to insert field "+field+" of settings.Server") {
|
||||
tx.Rollback()
|
||||
break
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if checkError(err, "Fail to commit") {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/filebrowser/filebrowser/v2/auth"
|
||||
"github.com/filebrowser/filebrowser/v2/files"
|
||||
@ -20,7 +21,7 @@ type settingsBackend struct {
|
||||
}
|
||||
|
||||
func InitSettingsTable(db *sql.DB) error {
|
||||
sql := "create table if not exists settings(key string primary key, value string)"
|
||||
sql := fmt.Sprintf("create table if not exists \"%s\"(key string primary key, value string)", SettingsTable)
|
||||
_, err := db.Exec(sql)
|
||||
checkError(err, "Fail to create table settings")
|
||||
return err
|
||||
@ -135,7 +136,7 @@ func boolToString(b bool) string {
|
||||
}
|
||||
|
||||
func (s settingsBackend) Get() (*settings.Settings, error) {
|
||||
sql := "select key, value from settings"
|
||||
sql := fmt.Sprintf("select key, value from \"%s\"", SettingsTable)
|
||||
rows, err := s.db.Query(sql)
|
||||
if checkError(err, "Fail to Query settings.Settings") {
|
||||
return nil, err
|
||||
@ -197,9 +198,9 @@ func (s settingsBackend) Save(ss *settings.Settings) error {
|
||||
}
|
||||
for i, field := range fields {
|
||||
exists := ContainKey(s.db, field)
|
||||
sql := "INSERT INTO settings (value, key) VALUES(?,?)"
|
||||
sql := fmt.Sprintf("INSERT INTO \"%s\" (value, key) VALUES(?,?)", SettingsTable)
|
||||
if exists {
|
||||
sql = "UPDATE settings set value = ? where key = ?"
|
||||
sql = fmt.Sprintf("UPDATE \"%s\" set value = ? where key = ?", SettingsTable)
|
||||
}
|
||||
stmt, err := s.db.Prepare(sql)
|
||||
defer stmt.Close()
|
||||
@ -221,15 +222,6 @@ func (s settingsBackend) Save(ss *settings.Settings) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var defaultServer = settings.Server{
|
||||
Port: "8080",
|
||||
Log: "stdout",
|
||||
EnableThumbnails: false,
|
||||
ResizePreview: false,
|
||||
EnableExec: false,
|
||||
TypeDetectionByHeader: false,
|
||||
}
|
||||
|
||||
var defaultSettings = settings.Settings{
|
||||
Key: []byte(""),
|
||||
Signup: false,
|
||||
@ -271,17 +263,6 @@ var defaultSettings = settings.Settings{
|
||||
Rules: make([]rules.Rule, 0),
|
||||
}
|
||||
|
||||
func cloneServer(server settings.Server) settings.Server {
|
||||
data, err := json.Marshal(server)
|
||||
s := settings.Server{}
|
||||
if checkError(err, "Fail to clone settings.Server") {
|
||||
return s
|
||||
}
|
||||
err = json.Unmarshal(data, &s)
|
||||
checkError(err, "Fail to decode for settings.Server")
|
||||
return s
|
||||
}
|
||||
|
||||
func cloneSettings(s settings.Settings) settings.Settings {
|
||||
data, err := json.Marshal(s)
|
||||
s1 := settings.Settings{}
|
||||
@ -292,95 +273,8 @@ func cloneSettings(s settings.Settings) settings.Settings {
|
||||
return s1
|
||||
}
|
||||
|
||||
func (s settingsBackend) GetServer() (*settings.Server, error) {
|
||||
sql := "select key, value from settings"
|
||||
rows, err := s.db.Query(sql)
|
||||
if checkError(err, "Fail to Query for GetServer") {
|
||||
return nil, err
|
||||
}
|
||||
server := cloneServer(defaultServer)
|
||||
key := ""
|
||||
value := ""
|
||||
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&key, &value)
|
||||
if checkError(err, "Fail to query settings.Settings") {
|
||||
continue
|
||||
}
|
||||
if key == "Root" {
|
||||
server.Root = value
|
||||
} else if key == "BaseURL" {
|
||||
server.BaseURL = value
|
||||
} else if key == "Socket" {
|
||||
server.Socket = value
|
||||
} else if key == "TLSKey" {
|
||||
server.TLSKey = value
|
||||
} else if key == "TLSCert" {
|
||||
server.TLSCert = value
|
||||
} else if key == "Port" {
|
||||
server.Port = value
|
||||
} else if key == "Address" {
|
||||
server.Address = value
|
||||
} else if key == "Log" {
|
||||
server.Log = value
|
||||
} else if key == "EnableThumbnails" {
|
||||
server.EnableThumbnails = boolFromString(value)
|
||||
} else if key == "ResizePreview" {
|
||||
server.ResizePreview = boolFromString(value)
|
||||
} else if key == "EnableExec" {
|
||||
server.EnableExec = boolFromString(value)
|
||||
} else if key == "TypeDetectionByHeader" {
|
||||
server.TypeDetectionByHeader = boolFromString(value)
|
||||
} else if key == "AuthHook" {
|
||||
server.AuthHook = value
|
||||
}
|
||||
}
|
||||
return &server, nil
|
||||
}
|
||||
|
||||
func (s settingsBackend) SaveServer(ss *settings.Server) error {
|
||||
fields := []string{"Root", "BaseURL", "Socket", "TLSKey", "TLSCert", "Port", "Address", "Log", "EnableThumbnails", "ResizePreview", "EnableExec", "TypeDetectionByHeader", "AuthHook"}
|
||||
values := []string{
|
||||
ss.Root,
|
||||
ss.BaseURL,
|
||||
ss.Socket,
|
||||
ss.TLSKey,
|
||||
ss.TLSCert,
|
||||
ss.Port,
|
||||
ss.Address,
|
||||
ss.Log,
|
||||
boolToString(ss.EnableThumbnails),
|
||||
boolToString(ss.ResizePreview),
|
||||
boolToString(ss.EnableExec),
|
||||
boolToString(ss.TypeDetectionByHeader),
|
||||
ss.AuthHook}
|
||||
tx, err := s.db.Begin()
|
||||
if checkError(err, "Fail to begin db transaction") {
|
||||
return err
|
||||
}
|
||||
for i, field := range fields {
|
||||
stmt, err := s.db.Prepare("INSERT INTO settings (key, value) VALUES(?,?)")
|
||||
defer stmt.Close()
|
||||
if checkError(err, "Fail to prepare statement") {
|
||||
tx.Rollback()
|
||||
break
|
||||
}
|
||||
_, err = stmt.Exec(field, values[i])
|
||||
if checkError(err, "Fail to insert field "+field+" of settings") {
|
||||
tx.Rollback()
|
||||
break
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if checkError(err, "Fail to commit") {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func SetSetting(db *sql.DB, key string, value string) error {
|
||||
sql := "select count(key) from settings where key = '" + key + "'"
|
||||
sql := fmt.Sprintf("select count(key) from \"%s\" where key = '%s'", SettingsTable, key)
|
||||
count := 0
|
||||
err := db.QueryRow(sql).Scan(&count)
|
||||
if checkError(err, "Fail to QueryRow for key="+key) {
|
||||
@ -393,7 +287,7 @@ func SetSetting(db *sql.DB, key string, value string) error {
|
||||
}
|
||||
|
||||
func GetSetting(db *sql.DB, key string) string {
|
||||
sql := "select value from settings where key = '" + key + "';"
|
||||
sql := fmt.Sprintf("select value from \"%s\" where key = '%s'", SettingsTable, key)
|
||||
value := ""
|
||||
err := db.QueryRow(sql).Scan(&value)
|
||||
if checkError(err, "Fail to QueryRow for key "+key) {
|
||||
@ -403,14 +297,14 @@ func GetSetting(db *sql.DB, key string) string {
|
||||
}
|
||||
|
||||
func addSetting(db *sql.DB, key string, value string) error {
|
||||
sql := "insert into settings(key, value) values('" + key + "', '" + value + "')"
|
||||
sql := fmt.Sprintf("insert into \"%s\" (key, value) values('%s', '%s')", SettingsTable, key, value)
|
||||
_, err := db.Exec(sql)
|
||||
checkError(err, "Fail to addSetting")
|
||||
return err
|
||||
}
|
||||
|
||||
func updateSetting(db *sql.DB, key string, value string) error {
|
||||
sql := "update settings set value = '" + value + "' where key = '" + key + "'"
|
||||
sql := fmt.Sprintf("update \"%s\" set value = '%s' where key = '%s'", SettingsTable, value, key)
|
||||
_, err := db.Exec(sql)
|
||||
checkError(err, "Fail to updateSetting")
|
||||
return err
|
||||
@ -425,7 +319,7 @@ func HadSetting(db *sql.DB) bool {
|
||||
}
|
||||
|
||||
func ContainKey(db *sql.DB, key string) bool {
|
||||
sql := "select value from settings where key = '" + key + "';"
|
||||
sql := fmt.Sprintf("select value from \"%s\" where key = '%s'", SettingsTable, key)
|
||||
value := ""
|
||||
err := db.QueryRow(sql).Scan(&value)
|
||||
if checkError(err, "Fail to QueryRow for key "+key) {
|
||||
|
||||
@ -15,10 +15,10 @@ type linkRecord interface {
|
||||
Scan(dest ...interface{}) error
|
||||
}
|
||||
|
||||
func InitShareTable(db *sql.DB) error {
|
||||
sql := "create table if not exists share_links (hash string, path string, userid integer, expire integer, passwordhash string, token string)"
|
||||
func InitSharesTable(db *sql.DB) error {
|
||||
sql := fmt.Sprintf("create table if not exists \"%s\" (hash string, path string, userid integer, expire integer, passwordhash string, token string)", SharesTable)
|
||||
_, err := db.Exec(sql)
|
||||
checkError(err, "Fail to InitShareTable")
|
||||
checkError(err, "Fail to InitSharesTable")
|
||||
return err
|
||||
}
|
||||
|
||||
@ -44,7 +44,7 @@ func parseLink(row linkRecord) (*share.Link, error) {
|
||||
}
|
||||
|
||||
func queryLinks(db *sql.DB, condition string) ([]*share.Link, error) {
|
||||
sql := "select hash, path, userid, expire, passwordhash, token from share_links"
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\"", SharesTable)
|
||||
if len(condition) > 0 {
|
||||
sql = sql + " where " + condition
|
||||
}
|
||||
@ -73,12 +73,12 @@ func (s shareBackend) FindByUserID(id uint) ([]*share.Link, error) {
|
||||
}
|
||||
|
||||
func (s shareBackend) GetByHash(hash string) (*share.Link, error) {
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from share_links where hash='%s'", hash)
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\" where hash='%s'", SharesTable, hash)
|
||||
return parseLink(s.db.QueryRow(sql))
|
||||
}
|
||||
|
||||
func (s shareBackend) GetPermanent(path string, id uint) (*share.Link, error) {
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from share_links where path='%s' and userid=%d", path, id)
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\" where path='%s' and userid=%d", SharesTable, path, id)
|
||||
return parseLink(s.db.QueryRow(sql))
|
||||
}
|
||||
|
||||
@ -87,13 +87,13 @@ func (s shareBackend) Gets(path string, id uint) ([]*share.Link, error) {
|
||||
return queryLinks(s.db, condition)
|
||||
}
|
||||
func (s shareBackend) Save(l *share.Link) error {
|
||||
sql := fmt.Sprintf("insert into share_links (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token)
|
||||
sql := fmt.Sprintf("insert into \"%s\" (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", SharesTable, l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to Save share")
|
||||
return err
|
||||
}
|
||||
func (s shareBackend) Delete(hash string) error {
|
||||
sql := fmt.Sprintf("DELETE FROM share_links WHERE hash='%s'", hash)
|
||||
sql := fmt.Sprintf("DELETE FROM \"%s\" WHERE hash='%s'", SharesTable, hash)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to Delete share")
|
||||
return err
|
||||
|
||||
@ -46,7 +46,7 @@ func connectDB(dbType string, path string) (*sql.DB, error) {
|
||||
func NewStorage(db *sql.DB) (*storage.Storage, error) {
|
||||
|
||||
InitUserTable(db)
|
||||
InitShareTable(db)
|
||||
InitSharesTable(db)
|
||||
InitSettingsTable(db)
|
||||
|
||||
userStore := users.NewStorage(usersBackend{db: db})
|
||||
@ -59,16 +59,6 @@ func NewStorage(db *sql.DB) (*storage.Storage, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: default
|
||||
/*
|
||||
if GetSetting(db, "auther") == "" {
|
||||
err := SetSetting(db, "auther", "json")
|
||||
if checkError(err, "Fail to set auther") {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
storage := &storage.Storage{
|
||||
Auth: authStore,
|
||||
Users: userStore,
|
||||
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/filebrowser/filebrowser/v2/errors"
|
||||
@ -116,7 +115,7 @@ func createAdminUser() users.User {
|
||||
}
|
||||
|
||||
func InitUserTable(db *sql.DB) error {
|
||||
sql := "create table if not exists users (id integer primary key, username string, password string, scope string, locale string, lockpassword bool, viewmode string, perm string, commands string, sorting string, rules string, hidedotfiles bool, dateformat bool, singleclick bool);"
|
||||
sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS \"%s\" (id integer primary key, username string, password string, scope string, locale string, lockpassword bool, viewmode string, perm string, commands string, sorting string, rules string, hidedotfiles bool, dateformat bool, singleclick bool);", UsersTable)
|
||||
_, err := db.Exec(sql)
|
||||
checkError(err, "Fail to create users table")
|
||||
return err
|
||||
@ -151,7 +150,7 @@ func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
|
||||
dateformat := false
|
||||
singleclick := false
|
||||
user := users.User{}
|
||||
sql := fmt.Sprintf("select %s from users where %s", columnsStr, conditionStr)
|
||||
sql := fmt.Sprintf("SELECT %s FROM \"%s\" WHERE %s", columnsStr, UsersTable, conditionStr)
|
||||
err := s.db.QueryRow(sql).Scan(&userID, &username, &password, &scope, &locale, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules, &hidedotfiles, &dateformat, &singleclick)
|
||||
if checkError(err, "") {
|
||||
return nil, err
|
||||
@ -174,7 +173,7 @@ func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
|
||||
}
|
||||
|
||||
func (s usersBackend) Gets() ([]*users.User, error) {
|
||||
sql := "select id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules from users"
|
||||
sql := fmt.Sprintf("SELECT id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules FROM \"%s\"", UsersTable)
|
||||
rows, err := s.db.Query(sql)
|
||||
if checkError(err, "Fail to Query []*users.User") {
|
||||
return nil, err
|
||||
@ -218,7 +217,8 @@ func (s usersBackend) updateUser(id uint, user *users.User) error {
|
||||
lockpassword = 1
|
||||
}
|
||||
sql := fmt.Sprintf(
|
||||
"update users set username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' where id=%d",
|
||||
"UPDATE \"%s\" SET username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' WHERE id=%d",
|
||||
UsersTable,
|
||||
user.Username,
|
||||
user.Password,
|
||||
user.Scope,
|
||||
@ -236,10 +236,6 @@ func (s usersBackend) updateUser(id uint, user *users.User) error {
|
||||
}
|
||||
|
||||
func (s usersBackend) insertUser(user *users.User) error {
|
||||
password, err := users.HashPwd(user.Password)
|
||||
if checkError(err, "Fail to hash password") {
|
||||
return err
|
||||
}
|
||||
columnSpec := [][]string{
|
||||
{"username", "'%s'"},
|
||||
{"password", "'%s'"},
|
||||
@ -263,11 +259,11 @@ func (s usersBackend) insertUser(user *users.User) error {
|
||||
}
|
||||
columnStr := strings.Join(columns, ",")
|
||||
specStr := strings.Join(specs, ",")
|
||||
sqlFormat := fmt.Sprintf("insert into users (%s) values (%s)", columnStr, specStr)
|
||||
sqlFormat := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)", UsersTable, columnStr, specStr)
|
||||
sql := fmt.Sprintf(
|
||||
sqlFormat,
|
||||
user.Username,
|
||||
password,
|
||||
user.Password,
|
||||
user.Scope,
|
||||
user.Locale,
|
||||
boolToString(user.LockPassword),
|
||||
@ -300,20 +296,23 @@ func (s usersBackend) Save(user *users.User) error {
|
||||
}
|
||||
|
||||
func (s usersBackend) DeleteByID(id uint) error {
|
||||
sql := "delete from users where id=" + strconv.Itoa(int(id))
|
||||
sql := fmt.Sprintf("delete from \"%s\" where id=%d", UsersTable, id)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to delete User by id")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s usersBackend) DeleteByUsername(username string) error {
|
||||
sql := "delete from users where username='" + username + "'"
|
||||
sql := fmt.Sprintf("delete from \"%s\" where username='%s'", UsersTable, username)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to delete user by username")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s usersBackend) Update(u *users.User, fields ...string) error {
|
||||
if len(fields) == 0 {
|
||||
return s.Save(u)
|
||||
}
|
||||
var setItems = []string{}
|
||||
for _, field := range fields {
|
||||
userField := reflect.ValueOf(u).Elem().FieldByName(field)
|
||||
@ -323,18 +322,16 @@ func (s usersBackend) Update(u *users.User, fields ...string) error {
|
||||
field = strings.ToLower(field)
|
||||
val := userField.Interface()
|
||||
typeStr := reflect.TypeOf(val).Kind().String()
|
||||
fmt.Println(typeStr)
|
||||
if typeStr == "string" {
|
||||
setItems = append(setItems, fmt.Sprintf("%s='%s'", field, val))
|
||||
setItems = append(setItems, fmt.Sprintf("\"%s\"='%s'", field, val))
|
||||
} else if typeStr == "bool" {
|
||||
setItems = append(setItems, fmt.Sprintf("%s=%s", field, boolToString(val.(bool))))
|
||||
setItems = append(setItems, fmt.Sprintf("\"%s\"=%s", field, boolToString(val.(bool))))
|
||||
} else {
|
||||
// TODO
|
||||
setItems = append(setItems, fmt.Sprintf("%s=%s", field, val))
|
||||
setItems = append(setItems, fmt.Sprintf("\"%s\"=%s", field, val))
|
||||
}
|
||||
}
|
||||
sql := fmt.Sprintf("update users set %s where id=%d", strings.Join(setItems, ","), u.ID)
|
||||
fmt.Println(sql)
|
||||
sql := fmt.Sprintf("UPDATE \"%s\" SET %s WHERE id=%d", UsersTable, strings.Join(setItems, ","), u.ID)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to update user")
|
||||
return err
|
||||
|
||||
Loading…
Reference in New Issue
Block a user