feat: use viper for flag/env/config handling

This commit is contained in:
Henrique Dias 2025-11-15 09:08:08 +01:00
parent c4c1cea230
commit 65b2ef42d9
No known key found for this signature in database
13 changed files with 172 additions and 399 deletions

View File

@ -2,6 +2,7 @@ package cmd
import ( import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
v "github.com/spf13/viper"
) )
func init() { func init() {
@ -19,11 +20,7 @@ var cmdsLsCmd = &cobra.Command{
if err != nil { if err != nil {
return err return err
} }
evt, err := getString(cmd.Flags(), "event") evt := v.GetString("event")
if err != nil {
return err
}
if evt == "" { if evt == "" {
printEvents(s.Commands) printEvents(s.Commands)
} else { } else {

View File

@ -10,6 +10,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag" "github.com/spf13/pflag"
v "github.com/spf13/viper"
"github.com/filebrowser/filebrowser/v2/auth" "github.com/filebrowser/filebrowser/v2/auth"
"github.com/filebrowser/filebrowser/v2/errors" "github.com/filebrowser/filebrowser/v2/errors"
@ -56,11 +57,8 @@ func addConfigFlags(flags *pflag.FlagSet) {
flags.String("dir-mode", fmt.Sprintf("%O", settings.DefaultDirMode), "Mode bits that new directories are created with") flags.String("dir-mode", fmt.Sprintf("%O", settings.DefaultDirMode), "Mode bits that new directories are created with")
} }
func getAuthMethod(flags *pflag.FlagSet, defaults ...interface{}) (settings.AuthMethod, map[string]interface{}, error) { func getAuthMethod(defaults ...interface{}) (settings.AuthMethod, map[string]interface{}, error) {
methodStr, err := getString(flags, "auth.method") methodStr := v.GetString("auth.method")
if err != nil {
return "", nil, err
}
method := settings.AuthMethod(methodStr) method := settings.AuthMethod(methodStr)
var defaultAuther map[string]interface{} var defaultAuther map[string]interface{}
@ -87,12 +85,8 @@ func getAuthMethod(flags *pflag.FlagSet, defaults ...interface{}) (settings.Auth
return method, defaultAuther, nil return method, defaultAuther, nil
} }
func getProxyAuth(flags *pflag.FlagSet, defaultAuther map[string]interface{}) (auth.Auther, error) { func getProxyAuth(defaultAuther map[string]interface{}) (auth.Auther, error) {
header, err := getString(flags, "auth.header") header := v.GetString("auth.header")
if err != nil {
return nil, err
}
if header == "" { if header == "" {
header = defaultAuther["header"].(string) header = defaultAuther["header"].(string)
} }
@ -108,20 +102,11 @@ func getNoAuth() auth.Auther {
return &auth.NoAuth{} return &auth.NoAuth{}
} }
func getJSONAuth(flags *pflag.FlagSet, defaultAuther map[string]interface{}) (auth.Auther, error) { func getJSONAuth(defaultAuther map[string]interface{}) (auth.Auther, error) {
jsonAuth := &auth.JSONAuth{} jsonAuth := &auth.JSONAuth{}
host, err := getString(flags, "recaptcha.host") host := v.GetString("recaptcha.host")
if err != nil { key := v.GetString("recaptcha.key")
return nil, err secret := v.GetString("recaptcha.secret")
}
key, err := getString(flags, "recaptcha.key")
if err != nil {
return nil, err
}
secret, err := getString(flags, "recaptcha.secret")
if err != nil {
return nil, err
}
if key == "" { if key == "" {
if kmap, ok := defaultAuther["recaptcha"].(map[string]interface{}); ok { if kmap, ok := defaultAuther["recaptcha"].(map[string]interface{}); ok {
@ -145,12 +130,8 @@ func getJSONAuth(flags *pflag.FlagSet, defaultAuther map[string]interface{}) (au
return jsonAuth, nil return jsonAuth, nil
} }
func getHookAuth(flags *pflag.FlagSet, defaultAuther map[string]interface{}) (auth.Auther, error) { func getHookAuth(defaultAuther map[string]interface{}) (auth.Auther, error) {
command, err := getString(flags, "auth.command") command := v.GetString("auth.command")
if err != nil {
return nil, err
}
if command == "" { if command == "" {
command = defaultAuther["command"].(string) command = defaultAuther["command"].(string)
} }
@ -162,8 +143,8 @@ func getHookAuth(flags *pflag.FlagSet, defaultAuther map[string]interface{}) (au
return &auth.HookAuth{Command: command}, nil return &auth.HookAuth{Command: command}, nil
} }
func getAuthentication(flags *pflag.FlagSet, defaults ...interface{}) (settings.AuthMethod, auth.Auther, error) { func getAuthentication(defaults ...interface{}) (settings.AuthMethod, auth.Auther, error) {
method, defaultAuther, err := getAuthMethod(flags, defaults...) method, defaultAuther, err := getAuthMethod(defaults...)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -171,13 +152,13 @@ func getAuthentication(flags *pflag.FlagSet, defaults ...interface{}) (settings.
var auther auth.Auther var auther auth.Auther
switch method { switch method {
case auth.MethodProxyAuth: case auth.MethodProxyAuth:
auther, err = getProxyAuth(flags, defaultAuther) auther, err = getProxyAuth(defaultAuther)
case auth.MethodNoAuth: case auth.MethodNoAuth:
auther = getNoAuth() auther = getNoAuth()
case auth.MethodJSONAuth: case auth.MethodJSONAuth:
auther, err = getJSONAuth(flags, defaultAuther) auther, err = getJSONAuth(defaultAuther)
case auth.MethodHookAuth: case auth.MethodHookAuth:
auther, err = getHookAuth(flags, defaultAuther) auther, err = getHookAuth(defaultAuther)
default: default:
return "", nil, errors.ErrInvalidAuthMethod return "", nil, errors.ErrInvalidAuthMethod
} }

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/spf13/cobra" "github.com/spf13/cobra"
v "github.com/spf13/viper"
"github.com/filebrowser/filebrowser/v2/settings" "github.com/filebrowser/filebrowser/v2/settings"
) )
@ -25,7 +26,7 @@ override the options.`,
RunE: python(func(cmd *cobra.Command, _ []string, d *pythonData) error { RunE: python(func(cmd *cobra.Command, _ []string, d *pythonData) error {
defaults := settings.UserDefaults{} defaults := settings.UserDefaults{}
flags := cmd.Flags() flags := cmd.Flags()
err := getUserDefaults(flags, &defaults, true) err := getUserDefaults(&defaults, true)
if err != nil { if err != nil {
return err return err
} }
@ -34,135 +35,43 @@ override the options.`,
return err return err
} }
key := generateKey()
signup, err := getBool(flags, "signup")
if err != nil {
return err
}
hideLoginButton, err := getBool(flags, "hide-login-button")
if err != nil {
return err
}
createUserDir, err := getBool(flags, "create-user-dir")
if err != nil {
return err
}
minLength, err := getUint(flags, "minimum-password-length")
if err != nil {
return err
}
shell, err := getString(flags, "shell")
if err != nil {
return err
}
brandingName, err := getString(flags, "branding.name")
if err != nil {
return err
}
brandingDisableExternal, err := getBool(flags, "branding.disableExternal")
if err != nil {
return err
}
brandingDisableUsedPercentage, err := getBool(flags, "branding.disableUsedPercentage")
if err != nil {
return err
}
brandingTheme, err := getString(flags, "branding.theme")
if err != nil {
return err
}
brandingFiles, err := getString(flags, "branding.files")
if err != nil {
return err
}
s := &settings.Settings{ s := &settings.Settings{
Key: key, Key: generateKey(),
Signup: signup, Signup: v.GetBool("signup"),
HideLoginButton: hideLoginButton, HideLoginButton: v.GetBool("hide-login-button"),
CreateUserDir: createUserDir, CreateUserDir: v.GetBool("create-user-dir"),
MinimumPasswordLength: minLength, MinimumPasswordLength: v.GetUint("minimum-password-length"),
Shell: convertCmdStrToCmdArray(shell), Shell: convertCmdStrToCmdArray(v.GetString("shell")),
AuthMethod: authMethod, AuthMethod: authMethod,
Defaults: defaults, Defaults: defaults,
Branding: settings.Branding{ Branding: settings.Branding{
Name: brandingName, Name: v.GetString("branding.name"),
DisableExternal: brandingDisableExternal, DisableExternal: v.GetBool("branding.disableexternal"),
DisableUsedPercentage: brandingDisableUsedPercentage, DisableUsedPercentage: v.GetBool("branding.disableusedpercentage"),
Theme: brandingTheme, Theme: v.GetString("branding.theme"),
Files: brandingFiles, Files: v.GetString("branding.files"),
}, },
} }
s.FileMode, err = getMode(flags, "file-mode") s.FileMode, err = getAndParseMode("file-mode")
if err != nil { if err != nil {
return err return err
} }
s.DirMode, err = getMode(flags, "dir-mode") s.DirMode, err = getAndParseMode("dir-mode")
if err != nil {
return err
}
address, err := getString(flags, "address")
if err != nil {
return err
}
socket, err := getString(flags, "socket")
if err != nil {
return err
}
root, err := getString(flags, "root")
if err != nil {
return err
}
baseURL, err := getString(flags, "baseurl")
if err != nil {
return err
}
tlsKey, err := getString(flags, "key")
if err != nil {
return err
}
cert, err := getString(flags, "cert")
if err != nil {
return err
}
port, err := getString(flags, "port")
if err != nil {
return err
}
log, err := getString(flags, "log")
if err != nil { if err != nil {
return err return err
} }
ser := &settings.Server{ ser := &settings.Server{
Address: address, Address: v.GetString("address"),
Socket: socket, Socket: v.GetString("socket"),
Root: root, Root: v.GetString("root"),
BaseURL: baseURL, BaseURL: v.GetString("baseurl"),
TLSKey: tlsKey, TLSKey: v.GetString("key"),
TLSCert: cert, TLSCert: v.GetString("cert"),
Port: port, Port: v.GetString("port"),
Log: log, Log: v.GetString("log"),
} }
err = d.store.Settings.Save(s) err = d.store.Settings.Save(s)

View File

@ -2,7 +2,7 @@ package cmd
import ( import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag" v "github.com/spf13/viper"
) )
func init() { func init() {
@ -29,65 +29,67 @@ you want to change. Other options will remain unchanged.`,
} }
hasAuth := false hasAuth := false
flags.Visit(func(flag *pflag.Flag) {
if err != nil { for _, key := range v.AllKeys() {
return if !v.IsSet(key) {
continue
} }
switch flag.Name {
switch key {
case "baseurl": case "baseurl":
ser.BaseURL, err = getString(flags, flag.Name) ser.BaseURL = v.GetString(key)
case "root": case "root":
ser.Root, err = getString(flags, flag.Name) ser.Root = v.GetString(key)
case "socket": case "socket":
ser.Socket, err = getString(flags, flag.Name) ser.Socket = v.GetString(key)
case "cert": case "cert":
ser.TLSCert, err = getString(flags, flag.Name) ser.TLSCert = v.GetString(key)
case "key": case "key":
ser.TLSKey, err = getString(flags, flag.Name) ser.TLSKey = v.GetString(key)
case "address": case "address":
ser.Address, err = getString(flags, flag.Name) ser.Address = v.GetString(key)
case "port": case "port":
ser.Port, err = getString(flags, flag.Name) ser.Port = v.GetString(key)
case "log": case "log":
ser.Log, err = getString(flags, flag.Name) ser.Log = v.GetString(key)
case "hide-login-button": case "hide-login-button":
set.HideLoginButton, err = getBool(flags, flag.Name) set.HideLoginButton = v.GetBool(key)
case "signup": case "signup":
set.Signup, err = getBool(flags, flag.Name) set.Signup = v.GetBool(key)
case "auth.method": case "auth.method":
hasAuth = true hasAuth = true
case "shell": case "shell":
var shell string var shell string
shell, err = getString(flags, flag.Name) shell = v.GetString(key)
set.Shell = convertCmdStrToCmdArray(shell) set.Shell = convertCmdStrToCmdArray(shell)
case "create-user-dir": case "create-user-dir":
set.CreateUserDir, err = getBool(flags, flag.Name) set.CreateUserDir = v.GetBool(key)
case "minimum-password-length": case "minimum-password-length":
set.MinimumPasswordLength, err = getUint(flags, flag.Name) set.MinimumPasswordLength = v.GetUint(key)
case "branding.name": case "branding.name":
set.Branding.Name, err = getString(flags, flag.Name) set.Branding.Name = v.GetString(key)
case "branding.color": case "branding.color":
set.Branding.Color, err = getString(flags, flag.Name) set.Branding.Color = v.GetString(key)
case "branding.theme": case "branding.theme":
set.Branding.Theme, err = getString(flags, flag.Name) set.Branding.Theme = v.GetString(key)
case "branding.disableExternal": case "branding.disableexternal":
set.Branding.DisableExternal, err = getBool(flags, flag.Name) set.Branding.DisableExternal = v.GetBool(key)
case "branding.disableUsedPercentage": case "branding.disableusedpercentage":
set.Branding.DisableUsedPercentage, err = getBool(flags, flag.Name) set.Branding.DisableUsedPercentage = v.GetBool(key)
case "branding.files": case "branding.files":
set.Branding.Files, err = getString(flags, flag.Name) set.Branding.Files = v.GetString(key)
case "file-mode": case "file-mode":
set.FileMode, err = getMode(flags, flag.Name) set.FileMode, err = getAndParseMode(key)
case "dir-mode": case "dir-mode":
set.DirMode, err = getMode(flags, flag.Name) set.DirMode, err = getAndParseMode(key)
} }
})
if err != nil { if err != nil {
return err return err
}
} }
err = getUserDefaults(flags, &set.Defaults, false) err = getUserDefaults(&set.Defaults, false)
if err != nil { if err != nil {
return err return err
} }

View File

@ -40,7 +40,7 @@ var docsCmd = &cobra.Command{
Hidden: true, Hidden: true,
Args: cobra.NoArgs, Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error { RunE: func(cmd *cobra.Command, _ []string) error {
dir, err := getString(cmd.Flags(), "path") dir, err := cmd.Flags().GetString("path")
if err != nil { if err != nil {
return err return err
} }

View File

@ -119,7 +119,7 @@ user created with the credentials from options "username" and "password".`,
log.Println(cfgFile) log.Println(cfgFile)
if !d.hadDB { if !d.hadDB {
err := quickSetup(cmd.Flags(), *d) err := quickSetup(*d)
if err != nil { if err != nil {
return err return err
} }
@ -147,7 +147,7 @@ user created with the credentials from options "username" and "password".`,
fileCache = diskcache.New(afero.NewOsFs(), cacheDir) fileCache = diskcache.New(afero.NewOsFs(), cacheDir)
} }
server, err := getRunParams(cmd.Flags(), d.store) server, err := getRunParams(d.store)
if err != nil { if err != nil {
return err return err
} }
@ -256,48 +256,48 @@ user created with the credentials from options "username" and "password".`,
}, pythonConfig{allowNoDB: true}), }, pythonConfig{allowNoDB: true}),
} }
func getRunParams(flags *pflag.FlagSet, st *storage.Storage) (*settings.Server, error) { func getRunParams(st *storage.Storage) (*settings.Server, error) {
server, err := st.Settings.GetServer() server, err := st.Settings.GetServer()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if val, set := getStringParamB(flags, "root"); set { if val, set := getStringParamB("root"); set {
server.Root = val server.Root = val
} }
if val, set := getStringParamB(flags, "baseurl"); set { if val, set := getStringParamB("baseurl"); set {
server.BaseURL = val server.BaseURL = val
} }
if val, set := getStringParamB(flags, "log"); set { if val, set := getStringParamB("log"); set {
server.Log = val server.Log = val
} }
isSocketSet := false isSocketSet := false
isAddrSet := false isAddrSet := false
if val, set := getStringParamB(flags, "address"); set { if val, set := getStringParamB("address"); set {
server.Address = val server.Address = val
isAddrSet = isAddrSet || set isAddrSet = isAddrSet || set
} }
if val, set := getStringParamB(flags, "port"); set { if val, set := getStringParamB("port"); set {
server.Port = val server.Port = val
isAddrSet = isAddrSet || set isAddrSet = isAddrSet || set
} }
if val, set := getStringParamB(flags, "key"); set { if val, set := getStringParamB("key"); set {
server.TLSKey = val server.TLSKey = val
isAddrSet = isAddrSet || set isAddrSet = isAddrSet || set
} }
if val, set := getStringParamB(flags, "cert"); set { if val, set := getStringParamB("cert"); set {
server.TLSCert = val server.TLSCert = val
isAddrSet = isAddrSet || set isAddrSet = isAddrSet || set
} }
if val, set := getStringParamB(flags, "socket"); set { if val, set := getStringParamB("socket"); set {
server.Socket = val server.Socket = val
isSocketSet = isSocketSet || set isSocketSet = isSocketSet || set
} }
@ -311,16 +311,16 @@ func getRunParams(flags *pflag.FlagSet, st *storage.Storage) (*settings.Server,
server.Socket = "" server.Socket = ""
} }
disableThumbnails := getBoolParam(flags, "disable-thumbnails") disableThumbnails := v.GetBool("disable-thumbnails")
server.EnableThumbnails = !disableThumbnails server.EnableThumbnails = !disableThumbnails
disablePreviewResize := getBoolParam(flags, "disable-preview-resize") disablePreviewResize := v.GetBool("disable-preview-resize")
server.ResizePreview = !disablePreviewResize server.ResizePreview = !disablePreviewResize
disableTypeDetectionByHeader := getBoolParam(flags, "disable-type-detection-by-header") disableTypeDetectionByHeader := v.GetBool("disable-type-detection-by-header")
server.TypeDetectionByHeader = !disableTypeDetectionByHeader server.TypeDetectionByHeader = !disableTypeDetectionByHeader
disableExec := getBoolParam(flags, "disable-exec") disableExec := v.GetBool("disable-exec")
server.EnableExec = !disableExec server.EnableExec = !disableExec
if server.EnableExec { if server.EnableExec {
@ -330,69 +330,15 @@ func getRunParams(flags *pflag.FlagSet, st *storage.Storage) (*settings.Server,
log.Println("WARNING: read https://github.com/filebrowser/filebrowser/issues/5199") log.Println("WARNING: read https://github.com/filebrowser/filebrowser/issues/5199")
} }
if val, set := getStringParamB(flags, "token-expiration-time"); set { if val, set := getStringParamB("token-expiration-time"); set {
server.TokenExpirationTime = val server.TokenExpirationTime = val
} }
return server, nil return server, nil
} }
// getBoolParamB returns a parameter as a string and a boolean to tell if it is different from the default func getStringParamB(key string) (string, bool) {
// return v.GetString(key), v.IsSet(key)
// NOTE: we could simply bind the flags to viper and use IsSet.
// Although there is a bug on Viper that always returns true on IsSet
// if a flag is binded. Our alternative way is to manually check
// the flag and then the value from env/config/gotten by viper.
// https://github.com/spf13/viper/pull/331
func getBoolParamB(flags *pflag.FlagSet, key string) (value, ok bool) {
value, _ = flags.GetBool(key)
// If set on Flags, use it.
if flags.Changed(key) {
return value, true
}
// If set through viper (env, config), return it.
if v.IsSet(key) {
return v.GetBool(key), true
}
// Otherwise use default value on flags.
return value, false
}
func getBoolParam(flags *pflag.FlagSet, key string) bool {
val, _ := getBoolParamB(flags, key)
return val
}
// getStringParamB returns a parameter as a string and a boolean to tell if it is different from the default
//
// NOTE: we could simply bind the flags to viper and use IsSet.
// Although there is a bug on Viper that always returns true on IsSet
// if a flag is binded. Our alternative way is to manually check
// the flag and then the value from env/config/gotten by viper.
// https://github.com/spf13/viper/pull/331
func getStringParamB(flags *pflag.FlagSet, key string) (string, bool) {
value, _ := flags.GetString(key)
// If set on Flags, use it.
if flags.Changed(key) {
return value, true
}
// If set through viper (env, config), return it.
if v.IsSet(key) {
return v.GetString(key), true
}
// Otherwise use default value on flags.
return value, false
}
func getStringParam(flags *pflag.FlagSet, key string) string {
val, _ := getStringParamB(flags, key)
return val
} }
func setupLog(logMethod string) { func setupLog(logMethod string) {
@ -413,7 +359,7 @@ func setupLog(logMethod string) {
} }
} }
func quickSetup(flags *pflag.FlagSet, d pythonData) error { func quickSetup(d pythonData) error {
log.Println("Performing quick setup") log.Println("Performing quick setup")
set := &settings.Settings{ set := &settings.Settings{
@ -427,7 +373,7 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) error {
Scope: ".", Scope: ".",
Locale: "en", Locale: "en",
SingleClick: false, SingleClick: false,
AceEditorTheme: getStringParam(flags, "defaults.aceEditorTheme"), AceEditorTheme: v.GetString("defaults.aceeditortheme"),
Perm: users.Permissions{ Perm: users.Permissions{
Admin: false, Admin: false,
Execute: true, Execute: true,
@ -451,7 +397,7 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) error {
} }
var err error var err error
if _, noauth := getStringParamB(flags, "noauth"); noauth { if _, noauth := getStringParamB("noauth"); noauth {
set.AuthMethod = auth.MethodNoAuth set.AuthMethod = auth.MethodNoAuth
err = d.store.Auth.Save(&auth.NoAuth{}) err = d.store.Auth.Save(&auth.NoAuth{})
} else { } else {
@ -468,13 +414,13 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) error {
} }
ser := &settings.Server{ ser := &settings.Server{
BaseURL: getStringParam(flags, "baseurl"), BaseURL: v.GetString("baseurl"),
Port: getStringParam(flags, "port"), Port: v.GetString("port"),
Log: getStringParam(flags, "log"), Log: v.GetString("log"),
TLSKey: getStringParam(flags, "key"), TLSKey: v.GetString("key"),
TLSCert: getStringParam(flags, "cert"), TLSCert: v.GetString("cert"),
Address: getStringParam(flags, "address"), Address: v.GetString("address"),
Root: getStringParam(flags, "root"), Root: v.GetString("root"),
} }
err = d.store.Settings.SaveServer(ser) err = d.store.Settings.SaveServer(ser)
@ -482,8 +428,8 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) error {
return err return err
} }
username := getStringParam(flags, "username") username := v.GetString("username")
password := getStringParam(flags, "password") password := v.GetString("password")
if password == "" { if password == "" {
var pwd string var pwd string

View File

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag" v "github.com/spf13/viper"
"github.com/filebrowser/filebrowser/v2/rules" "github.com/filebrowser/filebrowser/v2/rules"
"github.com/filebrowser/filebrowser/v2/settings" "github.com/filebrowser/filebrowser/v2/settings"
@ -30,7 +30,7 @@ rules.`,
} }
func runRules(st *storage.Storage, cmd *cobra.Command, usersFn func(*users.User) error, globalFn func(*settings.Settings) error) error { func runRules(st *storage.Storage, cmd *cobra.Command, usersFn func(*users.User) error, globalFn func(*settings.Settings) error) error {
id, err := getUserIdentifier(cmd.Flags()) id, err := getUserIdentifier()
if err != nil { if err != nil {
return err return err
} }
@ -68,15 +68,9 @@ func runRules(st *storage.Storage, cmd *cobra.Command, usersFn func(*users.User)
return nil return nil
} }
func getUserIdentifier(flags *pflag.FlagSet) (interface{}, error) { func getUserIdentifier() (interface{}, error) {
id, err := getUint(flags, "id") id := v.GetUint("id")
if err != nil { username := v.GetString("username")
return nil, err
}
username, err := getString(flags, "username")
if err != nil {
return nil, err
}
if id != 0 { if id != 0 {
return id, nil return id, nil

View File

@ -4,6 +4,7 @@ import (
"regexp" "regexp"
"github.com/spf13/cobra" "github.com/spf13/cobra"
v "github.com/spf13/viper"
"github.com/filebrowser/filebrowser/v2/rules" "github.com/filebrowser/filebrowser/v2/rules"
"github.com/filebrowser/filebrowser/v2/settings" "github.com/filebrowser/filebrowser/v2/settings"
@ -22,14 +23,8 @@ var rulesAddCmd = &cobra.Command{
Long: `Add a global rule or user rule.`, Long: `Add a global rule or user rule.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: python(func(cmd *cobra.Command, args []string, d *pythonData) error { RunE: python(func(cmd *cobra.Command, args []string, d *pythonData) error {
allow, err := getBool(cmd.Flags(), "allow") allow := v.GetBool("allow")
if err != nil { regex := v.GetBool("regex")
return err
}
regex, err := getBool(cmd.Flags(), "regex")
if err != nil {
return err
}
exp := args[0] exp := args[0]
if regex { if regex {

View File

@ -9,6 +9,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag" "github.com/spf13/pflag"
v "github.com/spf13/viper"
"github.com/filebrowser/filebrowser/v2/settings" "github.com/filebrowser/filebrowser/v2/settings"
"github.com/filebrowser/filebrowser/v2/users" "github.com/filebrowser/filebrowser/v2/users"
@ -82,11 +83,8 @@ func addUserFlags(flags *pflag.FlagSet) {
flags.String("aceEditorTheme", "", "ace editor's syntax highlighting theme for users") flags.String("aceEditorTheme", "", "ace editor's syntax highlighting theme for users")
} }
func getViewMode(flags *pflag.FlagSet) (users.ViewMode, error) { func getViewMode() (users.ViewMode, error) {
viewModeStr, err := getString(flags, "viewMode") viewModeStr := v.GetString("viewmode")
if err != nil {
return "", err
}
viewMode := users.ViewMode(viewModeStr) viewMode := users.ViewMode(viewModeStr)
if viewMode != users.ListViewMode && viewMode != users.MosaicViewMode { if viewMode != users.ListViewMode && viewMode != users.MosaicViewMode {
return "", errors.New("view mode must be \"" + string(users.ListViewMode) + "\" or \"" + string(users.MosaicViewMode) + "\"") return "", errors.New("view mode must be \"" + string(users.ListViewMode) + "\" or \"" + string(users.MosaicViewMode) + "\"")
@ -94,58 +92,55 @@ func getViewMode(flags *pflag.FlagSet) (users.ViewMode, error) {
return viewMode, nil return viewMode, nil
} }
func getUserDefaults(flags *pflag.FlagSet, defaults *settings.UserDefaults, all bool) error { func getUserDefaults(defaults *settings.UserDefaults, all bool) error {
var visitErr error keys := v.AllKeys()
visit := func(flag *pflag.Flag) {
if visitErr != nil { for _, key := range keys {
return if !all && !v.IsSet(key) {
continue
} }
var err error var err error
switch flag.Name { switch key {
case "scope": case "scope":
defaults.Scope, err = getString(flags, flag.Name) defaults.Scope = v.GetString(key)
case "locale": case "locale":
defaults.Locale, err = getString(flags, flag.Name) defaults.Locale = v.GetString(key)
case "viewMode": case "viewmode":
defaults.ViewMode, err = getViewMode(flags) defaults.ViewMode, err = getViewMode()
case "singleClick": case "singleclick":
defaults.SingleClick, err = getBool(flags, flag.Name) defaults.SingleClick = v.GetBool(key)
case "aceEditorTheme": case "aceeditortheme":
defaults.AceEditorTheme, err = getString(flags, flag.Name) defaults.AceEditorTheme = v.GetString(key)
case "perm.admin": case "perm.admin":
defaults.Perm.Admin, err = getBool(flags, flag.Name) defaults.Perm.Admin = v.GetBool(key)
case "perm.execute": case "perm.execute":
defaults.Perm.Execute, err = getBool(flags, flag.Name) defaults.Perm.Execute = v.GetBool(key)
case "perm.create": case "perm.create":
defaults.Perm.Create, err = getBool(flags, flag.Name) defaults.Perm.Create = v.GetBool(key)
case "perm.rename": case "perm.rename":
defaults.Perm.Rename, err = getBool(flags, flag.Name) defaults.Perm.Rename = v.GetBool(key)
case "perm.modify": case "perm.modify":
defaults.Perm.Modify, err = getBool(flags, flag.Name) defaults.Perm.Modify = v.GetBool(key)
case "perm.delete": case "perm.delete":
defaults.Perm.Delete, err = getBool(flags, flag.Name) defaults.Perm.Delete = v.GetBool(key)
case "perm.share": case "perm.share":
defaults.Perm.Share, err = getBool(flags, flag.Name) defaults.Perm.Share = v.GetBool(key)
case "perm.download": case "perm.download":
defaults.Perm.Download, err = getBool(flags, flag.Name) defaults.Perm.Download = v.GetBool(key)
case "commands": case "commands":
defaults.Commands, err = flags.GetStringSlice(flag.Name) defaults.Commands = v.GetStringSlice(key)
case "sorting.by": case "sorting.by":
defaults.Sorting.By, err = getString(flags, flag.Name) defaults.Sorting.By = v.GetString(key)
case "sorting.asc": case "sorting.asc":
defaults.Sorting.Asc, err = getBool(flags, flag.Name) defaults.Sorting.Asc = v.GetBool(key)
case "hideDotfiles": case "hidedotfiles":
defaults.HideDotfiles, err = getBool(flags, flag.Name) defaults.HideDotfiles = v.GetBool(key)
} }
if err != nil { if err != nil {
visitErr = err return err
} }
} }
if all { return nil
flags.VisitAll(visit)
} else {
flags.Visit(visit)
}
return visitErr
} }

View File

@ -2,6 +2,7 @@ package cmd
import ( import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
v "github.com/spf13/viper"
"github.com/filebrowser/filebrowser/v2/users" "github.com/filebrowser/filebrowser/v2/users"
) )
@ -21,7 +22,7 @@ var usersAddCmd = &cobra.Command{
if err != nil { if err != nil {
return err return err
} }
err = getUserDefaults(cmd.Flags(), &s.Defaults, false) err = getUserDefaults(&s.Defaults, false)
if err != nil { if err != nil {
return err return err
} }
@ -31,27 +32,12 @@ var usersAddCmd = &cobra.Command{
return err return err
} }
lockPassword, err := getBool(cmd.Flags(), "lockPassword")
if err != nil {
return err
}
dateFormat, err := getBool(cmd.Flags(), "dateFormat")
if err != nil {
return err
}
hideDotfiles, err := getBool(cmd.Flags(), "hideDotfiles")
if err != nil {
return err
}
user := &users.User{ user := &users.User{
Username: args[0], Username: args[0],
Password: password, Password: password,
LockPassword: lockPassword, LockPassword: v.GetBool("lockpassword"),
DateFormat: dateFormat, DateFormat: v.GetBool("dateformat"),
HideDotfiles: hideDotfiles, HideDotfiles: v.GetBool("hidedotfiles"),
} }
s.Defaults.Apply(user) s.Defaults.Apply(user)

View File

@ -7,6 +7,7 @@ import (
"strconv" "strconv"
"github.com/spf13/cobra" "github.com/spf13/cobra"
v "github.com/spf13/viper"
"github.com/filebrowser/filebrowser/v2/users" "github.com/filebrowser/filebrowser/v2/users"
) )
@ -45,12 +46,7 @@ list or set it to 0.`,
} }
} }
replace, err := getBool(cmd.Flags(), "replace") if v.GetBool("replace") {
if err != nil {
return err
}
if replace {
oldUsers, userImportErr := d.store.Users.Gets("") oldUsers, userImportErr := d.store.Users.Gets("")
if userImportErr != nil { if userImportErr != nil {
return userImportErr return userImportErr
@ -69,10 +65,7 @@ list or set it to 0.`,
} }
} }
overwrite, err := getBool(cmd.Flags(), "overwrite") overwrite := v.GetBool("overwrite")
if err != nil {
return err
}
for _, user := range list { for _, user := range list {
onDB, err := d.store.Users.Get("", user.ID) onDB, err := d.store.Users.Get("", user.ID)

View File

@ -2,6 +2,7 @@ package cmd
import ( import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
v "github.com/spf13/viper"
"github.com/filebrowser/filebrowser/v2/settings" "github.com/filebrowser/filebrowser/v2/settings"
"github.com/filebrowser/filebrowser/v2/users" "github.com/filebrowser/filebrowser/v2/users"
@ -23,15 +24,8 @@ options you want to change.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: python(func(cmd *cobra.Command, args []string, d *pythonData) error { RunE: python(func(cmd *cobra.Command, args []string, d *pythonData) error {
username, id := parseUsernameOrID(args[0]) username, id := parseUsernameOrID(args[0])
flags := cmd.Flags() password := v.GetString("password")
password, err := getString(flags, "password") newUsername := v.GetString("username")
if err != nil {
return err
}
newUsername, err := getString(flags, "username")
if err != nil {
return err
}
s, err := d.store.Settings.Get() s, err := d.store.Settings.Get()
if err != nil { if err != nil {
@ -61,7 +55,7 @@ options you want to change.`,
Sorting: user.Sorting, Sorting: user.Sorting,
Commands: user.Commands, Commands: user.Commands,
} }
err = getUserDefaults(flags, &defaults, false) err = getUserDefaults(&defaults, false)
if err != nil { if err != nil {
return err return err
} }
@ -72,18 +66,9 @@ options you want to change.`,
user.Perm = defaults.Perm user.Perm = defaults.Perm
user.Commands = defaults.Commands user.Commands = defaults.Commands
user.Sorting = defaults.Sorting user.Sorting = defaults.Sorting
user.LockPassword, err = getBool(flags, "lockPassword") user.LockPassword = v.GetBool("lockpassword")
if err != nil { user.DateFormat = v.GetBool("dateformat")
return err user.HideDotfiles = v.GetBool("hidedotfiles")
}
user.DateFormat, err = getBool(flags, "dateFormat")
if err != nil {
return err
}
user.HideDotfiles, err = getBool(flags, "hideDotfiles")
if err != nil {
return err
}
if newUsername != "" { if newUsername != "" {
user.Username = newUsername user.Username = newUsername

View File

@ -13,7 +13,7 @@ import (
"github.com/asdine/storm/v3" "github.com/asdine/storm/v3"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag" v "github.com/spf13/viper"
yaml "gopkg.in/yaml.v3" yaml "gopkg.in/yaml.v3"
"github.com/filebrowser/filebrowser/v2/settings" "github.com/filebrowser/filebrowser/v2/settings"
@ -23,15 +23,8 @@ import (
const dbPerms = 0640 const dbPerms = 0640
func getString(flags *pflag.FlagSet, flag string) (string, error) { func getAndParseMode(param string) (fs.FileMode, error) {
return flags.GetString(flag) s := v.GetString(param)
}
func getMode(flags *pflag.FlagSet, flag string) (fs.FileMode, error) {
s, err := getString(flags, flag)
if err != nil {
return 0, err
}
b, err := strconv.ParseUint(s, 0, 32) b, err := strconv.ParseUint(s, 0, 32)
if err != nil { if err != nil {
return 0, err return 0, err
@ -39,14 +32,6 @@ func getMode(flags *pflag.FlagSet, flag string) (fs.FileMode, error) {
return fs.FileMode(b), nil return fs.FileMode(b), nil
} }
func getBool(flags *pflag.FlagSet, flag string) (bool, error) {
return flags.GetBool(flag)
}
func getUint(flags *pflag.FlagSet, flag string) (uint, error) {
return flags.GetUint(flag)
}
func generateKey() []byte { func generateKey() []byte {
k, err := settings.GenerateKey() k, err := settings.GenerateKey()
if err != nil { if err != nil {
@ -91,9 +76,14 @@ func dbExists(path string) (bool, error) {
func python(fn pythonFunc, cfg pythonConfig) cobraFunc { func python(fn pythonFunc, cfg pythonConfig) cobraFunc {
return func(cmd *cobra.Command, args []string) error { return func(cmd *cobra.Command, args []string) error {
err := v.BindPFlags(cmd.Flags())
if err != nil {
panic(err)
}
data := &pythonData{hadDB: true} data := &pythonData{hadDB: true}
path := getStringParam(cmd.Flags(), "database") path := v.GetString("database")
absPath, err := filepath.Abs(path) absPath, err := filepath.Abs(path)
if err != nil { if err != nil {
panic(err) panic(err)