diff --git a/cmd/cmds_ls.go b/cmd/cmds_ls.go index fa901a56..fe1459e6 100644 --- a/cmd/cmds_ls.go +++ b/cmd/cmds_ls.go @@ -2,6 +2,7 @@ package cmd import ( "github.com/spf13/cobra" + v "github.com/spf13/viper" ) func init() { @@ -19,11 +20,7 @@ var cmdsLsCmd = &cobra.Command{ if err != nil { return err } - evt, err := getString(cmd.Flags(), "event") - if err != nil { - return err - } - + evt := v.GetString("event") if evt == "" { printEvents(s.Commands) } else { diff --git a/cmd/config.go b/cmd/config.go index 84474f4c..30d06413 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -10,6 +10,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" + v "github.com/spf13/viper" "github.com/filebrowser/filebrowser/v2/auth" "github.com/filebrowser/filebrowser/v2/errors" @@ -56,11 +57,8 @@ func addConfigFlags(flags *pflag.FlagSet) { flags.String("dir-mode", fmt.Sprintf("%O", settings.DefaultDirMode), "Mode bits that new directories are created with") } -func getAuthMethod(flags *pflag.FlagSet, defaults ...interface{}) (settings.AuthMethod, map[string]interface{}, error) { - methodStr, err := getString(flags, "auth.method") - if err != nil { - return "", nil, err - } +func getAuthMethod(defaults ...interface{}) (settings.AuthMethod, map[string]interface{}, error) { + methodStr := v.GetString("auth.method") method := settings.AuthMethod(methodStr) var defaultAuther map[string]interface{} @@ -87,12 +85,8 @@ func getAuthMethod(flags *pflag.FlagSet, defaults ...interface{}) (settings.Auth return method, defaultAuther, nil } -func getProxyAuth(flags *pflag.FlagSet, defaultAuther map[string]interface{}) (auth.Auther, error) { - header, err := getString(flags, "auth.header") - if err != nil { - return nil, err - } - +func getProxyAuth(defaultAuther map[string]interface{}) (auth.Auther, error) { + header := v.GetString("auth.header") if header == "" { header = defaultAuther["header"].(string) } @@ -108,20 +102,11 @@ func getNoAuth() auth.Auther { return &auth.NoAuth{} } -func getJSONAuth(flags *pflag.FlagSet, defaultAuther map[string]interface{}) (auth.Auther, error) { +func getJSONAuth(defaultAuther map[string]interface{}) (auth.Auther, error) { jsonAuth := &auth.JSONAuth{} - host, err := getString(flags, "recaptcha.host") - if err != nil { - return nil, err - } - key, err := getString(flags, "recaptcha.key") - if err != nil { - return nil, err - } - secret, err := getString(flags, "recaptcha.secret") - if err != nil { - return nil, err - } + host := v.GetString("recaptcha.host") + key := v.GetString("recaptcha.key") + secret := v.GetString("recaptcha.secret") if key == "" { if kmap, ok := defaultAuther["recaptcha"].(map[string]interface{}); ok { @@ -145,12 +130,8 @@ func getJSONAuth(flags *pflag.FlagSet, defaultAuther map[string]interface{}) (au return jsonAuth, nil } -func getHookAuth(flags *pflag.FlagSet, defaultAuther map[string]interface{}) (auth.Auther, error) { - command, err := getString(flags, "auth.command") - if err != nil { - return nil, err - } - +func getHookAuth(defaultAuther map[string]interface{}) (auth.Auther, error) { + command := v.GetString("auth.command") if command == "" { command = defaultAuther["command"].(string) } @@ -162,8 +143,8 @@ func getHookAuth(flags *pflag.FlagSet, defaultAuther map[string]interface{}) (au return &auth.HookAuth{Command: command}, nil } -func getAuthentication(flags *pflag.FlagSet, defaults ...interface{}) (settings.AuthMethod, auth.Auther, error) { - method, defaultAuther, err := getAuthMethod(flags, defaults...) +func getAuthentication(defaults ...interface{}) (settings.AuthMethod, auth.Auther, error) { + method, defaultAuther, err := getAuthMethod(defaults...) if err != nil { return "", nil, err } @@ -171,13 +152,13 @@ func getAuthentication(flags *pflag.FlagSet, defaults ...interface{}) (settings. var auther auth.Auther switch method { case auth.MethodProxyAuth: - auther, err = getProxyAuth(flags, defaultAuther) + auther, err = getProxyAuth(defaultAuther) case auth.MethodNoAuth: auther = getNoAuth() case auth.MethodJSONAuth: - auther, err = getJSONAuth(flags, defaultAuther) + auther, err = getJSONAuth(defaultAuther) case auth.MethodHookAuth: - auther, err = getHookAuth(flags, defaultAuther) + auther, err = getHookAuth(defaultAuther) default: return "", nil, errors.ErrInvalidAuthMethod } diff --git a/cmd/config_init.go b/cmd/config_init.go index 693b6ace..fa198c21 100644 --- a/cmd/config_init.go +++ b/cmd/config_init.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/spf13/cobra" + v "github.com/spf13/viper" "github.com/filebrowser/filebrowser/v2/settings" ) @@ -25,7 +26,7 @@ override the options.`, RunE: python(func(cmd *cobra.Command, _ []string, d *pythonData) error { defaults := settings.UserDefaults{} flags := cmd.Flags() - err := getUserDefaults(flags, &defaults, true) + err := getUserDefaults(&defaults, true) if err != nil { return err } @@ -34,135 +35,43 @@ override the options.`, return err } - key := generateKey() - - signup, err := getBool(flags, "signup") - if err != nil { - return err - } - - hideLoginButton, err := getBool(flags, "hide-login-button") - if err != nil { - return err - } - - createUserDir, err := getBool(flags, "create-user-dir") - if err != nil { - return err - } - - minLength, err := getUint(flags, "minimum-password-length") - if err != nil { - return err - } - - shell, err := getString(flags, "shell") - if err != nil { - return err - } - - brandingName, err := getString(flags, "branding.name") - if err != nil { - return err - } - - brandingDisableExternal, err := getBool(flags, "branding.disableExternal") - if err != nil { - return err - } - - brandingDisableUsedPercentage, err := getBool(flags, "branding.disableUsedPercentage") - if err != nil { - return err - } - - brandingTheme, err := getString(flags, "branding.theme") - if err != nil { - return err - } - - brandingFiles, err := getString(flags, "branding.files") - if err != nil { - return err - } - s := &settings.Settings{ - Key: key, - Signup: signup, - HideLoginButton: hideLoginButton, - CreateUserDir: createUserDir, - MinimumPasswordLength: minLength, - Shell: convertCmdStrToCmdArray(shell), + Key: generateKey(), + Signup: v.GetBool("signup"), + HideLoginButton: v.GetBool("hide-login-button"), + CreateUserDir: v.GetBool("create-user-dir"), + MinimumPasswordLength: v.GetUint("minimum-password-length"), + Shell: convertCmdStrToCmdArray(v.GetString("shell")), AuthMethod: authMethod, Defaults: defaults, Branding: settings.Branding{ - Name: brandingName, - DisableExternal: brandingDisableExternal, - DisableUsedPercentage: brandingDisableUsedPercentage, - Theme: brandingTheme, - Files: brandingFiles, + Name: v.GetString("branding.name"), + DisableExternal: v.GetBool("branding.disableexternal"), + DisableUsedPercentage: v.GetBool("branding.disableusedpercentage"), + Theme: v.GetString("branding.theme"), + Files: v.GetString("branding.files"), }, } - s.FileMode, err = getMode(flags, "file-mode") + s.FileMode, err = getAndParseMode("file-mode") if err != nil { return err } - s.DirMode, err = getMode(flags, "dir-mode") - if err != nil { - return err - } - - address, err := getString(flags, "address") - if err != nil { - return err - } - - socket, err := getString(flags, "socket") - if err != nil { - return err - } - - root, err := getString(flags, "root") - if err != nil { - return err - } - - baseURL, err := getString(flags, "baseurl") - if err != nil { - return err - } - - tlsKey, err := getString(flags, "key") - if err != nil { - return err - } - - cert, err := getString(flags, "cert") - if err != nil { - return err - } - - port, err := getString(flags, "port") - if err != nil { - return err - } - - log, err := getString(flags, "log") + s.DirMode, err = getAndParseMode("dir-mode") if err != nil { return err } ser := &settings.Server{ - Address: address, - Socket: socket, - Root: root, - BaseURL: baseURL, - TLSKey: tlsKey, - TLSCert: cert, - Port: port, - Log: log, + Address: v.GetString("address"), + Socket: v.GetString("socket"), + Root: v.GetString("root"), + BaseURL: v.GetString("baseurl"), + TLSKey: v.GetString("key"), + TLSCert: v.GetString("cert"), + Port: v.GetString("port"), + Log: v.GetString("log"), } err = d.store.Settings.Save(s) diff --git a/cmd/config_set.go b/cmd/config_set.go index 255ef470..84ad8d02 100644 --- a/cmd/config_set.go +++ b/cmd/config_set.go @@ -2,7 +2,7 @@ package cmd import ( "github.com/spf13/cobra" - "github.com/spf13/pflag" + v "github.com/spf13/viper" ) func init() { @@ -29,65 +29,67 @@ you want to change. Other options will remain unchanged.`, } hasAuth := false - flags.Visit(func(flag *pflag.Flag) { - if err != nil { - return + + for _, key := range v.AllKeys() { + if !v.IsSet(key) { + continue } - switch flag.Name { + + switch key { case "baseurl": - ser.BaseURL, err = getString(flags, flag.Name) + ser.BaseURL = v.GetString(key) case "root": - ser.Root, err = getString(flags, flag.Name) + ser.Root = v.GetString(key) case "socket": - ser.Socket, err = getString(flags, flag.Name) + ser.Socket = v.GetString(key) case "cert": - ser.TLSCert, err = getString(flags, flag.Name) + ser.TLSCert = v.GetString(key) case "key": - ser.TLSKey, err = getString(flags, flag.Name) + ser.TLSKey = v.GetString(key) case "address": - ser.Address, err = getString(flags, flag.Name) + ser.Address = v.GetString(key) case "port": - ser.Port, err = getString(flags, flag.Name) + ser.Port = v.GetString(key) case "log": - ser.Log, err = getString(flags, flag.Name) + ser.Log = v.GetString(key) case "hide-login-button": - set.HideLoginButton, err = getBool(flags, flag.Name) + set.HideLoginButton = v.GetBool(key) case "signup": - set.Signup, err = getBool(flags, flag.Name) + set.Signup = v.GetBool(key) case "auth.method": hasAuth = true case "shell": var shell string - shell, err = getString(flags, flag.Name) + shell = v.GetString(key) set.Shell = convertCmdStrToCmdArray(shell) case "create-user-dir": - set.CreateUserDir, err = getBool(flags, flag.Name) + set.CreateUserDir = v.GetBool(key) case "minimum-password-length": - set.MinimumPasswordLength, err = getUint(flags, flag.Name) + set.MinimumPasswordLength = v.GetUint(key) case "branding.name": - set.Branding.Name, err = getString(flags, flag.Name) + set.Branding.Name = v.GetString(key) case "branding.color": - set.Branding.Color, err = getString(flags, flag.Name) + set.Branding.Color = v.GetString(key) case "branding.theme": - set.Branding.Theme, err = getString(flags, flag.Name) - case "branding.disableExternal": - set.Branding.DisableExternal, err = getBool(flags, flag.Name) - case "branding.disableUsedPercentage": - set.Branding.DisableUsedPercentage, err = getBool(flags, flag.Name) + set.Branding.Theme = v.GetString(key) + case "branding.disableexternal": + set.Branding.DisableExternal = v.GetBool(key) + case "branding.disableusedpercentage": + set.Branding.DisableUsedPercentage = v.GetBool(key) case "branding.files": - set.Branding.Files, err = getString(flags, flag.Name) + set.Branding.Files = v.GetString(key) case "file-mode": - set.FileMode, err = getMode(flags, flag.Name) + set.FileMode, err = getAndParseMode(key) case "dir-mode": - set.DirMode, err = getMode(flags, flag.Name) + set.DirMode, err = getAndParseMode(key) } - }) - if err != nil { - return err + if err != nil { + return err + } } - err = getUserDefaults(flags, &set.Defaults, false) + err = getUserDefaults(&set.Defaults, false) if err != nil { return err } diff --git a/cmd/docs.go b/cmd/docs.go index 90e5a259..763f4322 100644 --- a/cmd/docs.go +++ b/cmd/docs.go @@ -40,7 +40,7 @@ var docsCmd = &cobra.Command{ Hidden: true, Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, _ []string) error { - dir, err := getString(cmd.Flags(), "path") + dir, err := cmd.Flags().GetString("path") if err != nil { return err } diff --git a/cmd/root.go b/cmd/root.go index 24f5d077..a8209805 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -119,7 +119,7 @@ user created with the credentials from options "username" and "password".`, log.Println(cfgFile) if !d.hadDB { - err := quickSetup(cmd.Flags(), *d) + err := quickSetup(*d) if err != nil { return err } @@ -147,7 +147,7 @@ user created with the credentials from options "username" and "password".`, fileCache = diskcache.New(afero.NewOsFs(), cacheDir) } - server, err := getRunParams(cmd.Flags(), d.store) + server, err := getRunParams(d.store) if err != nil { return err } @@ -256,48 +256,48 @@ user created with the credentials from options "username" and "password".`, }, pythonConfig{allowNoDB: true}), } -func getRunParams(flags *pflag.FlagSet, st *storage.Storage) (*settings.Server, error) { +func getRunParams(st *storage.Storage) (*settings.Server, error) { server, err := st.Settings.GetServer() if err != nil { return nil, err } - if val, set := getStringParamB(flags, "root"); set { + if val, set := getStringParamB("root"); set { server.Root = val } - if val, set := getStringParamB(flags, "baseurl"); set { + if val, set := getStringParamB("baseurl"); set { server.BaseURL = val } - if val, set := getStringParamB(flags, "log"); set { + if val, set := getStringParamB("log"); set { server.Log = val } isSocketSet := false isAddrSet := false - if val, set := getStringParamB(flags, "address"); set { + if val, set := getStringParamB("address"); set { server.Address = val isAddrSet = isAddrSet || set } - if val, set := getStringParamB(flags, "port"); set { + if val, set := getStringParamB("port"); set { server.Port = val isAddrSet = isAddrSet || set } - if val, set := getStringParamB(flags, "key"); set { + if val, set := getStringParamB("key"); set { server.TLSKey = val isAddrSet = isAddrSet || set } - if val, set := getStringParamB(flags, "cert"); set { + if val, set := getStringParamB("cert"); set { server.TLSCert = val isAddrSet = isAddrSet || set } - if val, set := getStringParamB(flags, "socket"); set { + if val, set := getStringParamB("socket"); set { server.Socket = val isSocketSet = isSocketSet || set } @@ -311,16 +311,16 @@ func getRunParams(flags *pflag.FlagSet, st *storage.Storage) (*settings.Server, server.Socket = "" } - disableThumbnails := getBoolParam(flags, "disable-thumbnails") + disableThumbnails := v.GetBool("disable-thumbnails") server.EnableThumbnails = !disableThumbnails - disablePreviewResize := getBoolParam(flags, "disable-preview-resize") + disablePreviewResize := v.GetBool("disable-preview-resize") server.ResizePreview = !disablePreviewResize - disableTypeDetectionByHeader := getBoolParam(flags, "disable-type-detection-by-header") + disableTypeDetectionByHeader := v.GetBool("disable-type-detection-by-header") server.TypeDetectionByHeader = !disableTypeDetectionByHeader - disableExec := getBoolParam(flags, "disable-exec") + disableExec := v.GetBool("disable-exec") server.EnableExec = !disableExec if server.EnableExec { @@ -330,69 +330,15 @@ func getRunParams(flags *pflag.FlagSet, st *storage.Storage) (*settings.Server, log.Println("WARNING: read https://github.com/filebrowser/filebrowser/issues/5199") } - if val, set := getStringParamB(flags, "token-expiration-time"); set { + if val, set := getStringParamB("token-expiration-time"); set { server.TokenExpirationTime = val } return server, nil } -// getBoolParamB returns a parameter as a string and a boolean to tell if it is different from the default -// -// NOTE: we could simply bind the flags to viper and use IsSet. -// Although there is a bug on Viper that always returns true on IsSet -// if a flag is binded. Our alternative way is to manually check -// the flag and then the value from env/config/gotten by viper. -// https://github.com/spf13/viper/pull/331 -func getBoolParamB(flags *pflag.FlagSet, key string) (value, ok bool) { - value, _ = flags.GetBool(key) - - // If set on Flags, use it. - if flags.Changed(key) { - return value, true - } - - // If set through viper (env, config), return it. - if v.IsSet(key) { - return v.GetBool(key), true - } - - // Otherwise use default value on flags. - return value, false -} - -func getBoolParam(flags *pflag.FlagSet, key string) bool { - val, _ := getBoolParamB(flags, key) - return val -} - -// getStringParamB returns a parameter as a string and a boolean to tell if it is different from the default -// -// NOTE: we could simply bind the flags to viper and use IsSet. -// Although there is a bug on Viper that always returns true on IsSet -// if a flag is binded. Our alternative way is to manually check -// the flag and then the value from env/config/gotten by viper. -// https://github.com/spf13/viper/pull/331 -func getStringParamB(flags *pflag.FlagSet, key string) (string, bool) { - value, _ := flags.GetString(key) - - // If set on Flags, use it. - if flags.Changed(key) { - return value, true - } - - // If set through viper (env, config), return it. - if v.IsSet(key) { - return v.GetString(key), true - } - - // Otherwise use default value on flags. - return value, false -} - -func getStringParam(flags *pflag.FlagSet, key string) string { - val, _ := getStringParamB(flags, key) - return val +func getStringParamB(key string) (string, bool) { + return v.GetString(key), v.IsSet(key) } func setupLog(logMethod string) { @@ -413,7 +359,7 @@ func setupLog(logMethod string) { } } -func quickSetup(flags *pflag.FlagSet, d pythonData) error { +func quickSetup(d pythonData) error { log.Println("Performing quick setup") set := &settings.Settings{ @@ -427,7 +373,7 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) error { Scope: ".", Locale: "en", SingleClick: false, - AceEditorTheme: getStringParam(flags, "defaults.aceEditorTheme"), + AceEditorTheme: v.GetString("defaults.aceeditortheme"), Perm: users.Permissions{ Admin: false, Execute: true, @@ -451,7 +397,7 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) error { } var err error - if _, noauth := getStringParamB(flags, "noauth"); noauth { + if _, noauth := getStringParamB("noauth"); noauth { set.AuthMethod = auth.MethodNoAuth err = d.store.Auth.Save(&auth.NoAuth{}) } else { @@ -468,13 +414,13 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) error { } ser := &settings.Server{ - BaseURL: getStringParam(flags, "baseurl"), - Port: getStringParam(flags, "port"), - Log: getStringParam(flags, "log"), - TLSKey: getStringParam(flags, "key"), - TLSCert: getStringParam(flags, "cert"), - Address: getStringParam(flags, "address"), - Root: getStringParam(flags, "root"), + BaseURL: v.GetString("baseurl"), + Port: v.GetString("port"), + Log: v.GetString("log"), + TLSKey: v.GetString("key"), + TLSCert: v.GetString("cert"), + Address: v.GetString("address"), + Root: v.GetString("root"), } err = d.store.Settings.SaveServer(ser) @@ -482,8 +428,8 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) error { return err } - username := getStringParam(flags, "username") - password := getStringParam(flags, "password") + username := v.GetString("username") + password := v.GetString("password") if password == "" { var pwd string diff --git a/cmd/rules.go b/cmd/rules.go index ffa5b1ae..23bdf673 100644 --- a/cmd/rules.go +++ b/cmd/rules.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/spf13/cobra" - "github.com/spf13/pflag" + v "github.com/spf13/viper" "github.com/filebrowser/filebrowser/v2/rules" "github.com/filebrowser/filebrowser/v2/settings" @@ -30,7 +30,7 @@ rules.`, } func runRules(st *storage.Storage, cmd *cobra.Command, usersFn func(*users.User) error, globalFn func(*settings.Settings) error) error { - id, err := getUserIdentifier(cmd.Flags()) + id, err := getUserIdentifier() if err != nil { return err } @@ -68,15 +68,9 @@ func runRules(st *storage.Storage, cmd *cobra.Command, usersFn func(*users.User) return nil } -func getUserIdentifier(flags *pflag.FlagSet) (interface{}, error) { - id, err := getUint(flags, "id") - if err != nil { - return nil, err - } - username, err := getString(flags, "username") - if err != nil { - return nil, err - } +func getUserIdentifier() (interface{}, error) { + id := v.GetUint("id") + username := v.GetString("username") if id != 0 { return id, nil diff --git a/cmd/rules_add.go b/cmd/rules_add.go index 9d1f0cf9..7f57856b 100644 --- a/cmd/rules_add.go +++ b/cmd/rules_add.go @@ -4,6 +4,7 @@ import ( "regexp" "github.com/spf13/cobra" + v "github.com/spf13/viper" "github.com/filebrowser/filebrowser/v2/rules" "github.com/filebrowser/filebrowser/v2/settings" @@ -22,14 +23,8 @@ var rulesAddCmd = &cobra.Command{ Long: `Add a global rule or user rule.`, Args: cobra.ExactArgs(1), RunE: python(func(cmd *cobra.Command, args []string, d *pythonData) error { - allow, err := getBool(cmd.Flags(), "allow") - if err != nil { - return err - } - regex, err := getBool(cmd.Flags(), "regex") - if err != nil { - return err - } + allow := v.GetBool("allow") + regex := v.GetBool("regex") exp := args[0] if regex { diff --git a/cmd/users.go b/cmd/users.go index c2e2ce1e..b37fe33e 100644 --- a/cmd/users.go +++ b/cmd/users.go @@ -9,6 +9,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" + v "github.com/spf13/viper" "github.com/filebrowser/filebrowser/v2/settings" "github.com/filebrowser/filebrowser/v2/users" @@ -82,11 +83,8 @@ func addUserFlags(flags *pflag.FlagSet) { flags.String("aceEditorTheme", "", "ace editor's syntax highlighting theme for users") } -func getViewMode(flags *pflag.FlagSet) (users.ViewMode, error) { - viewModeStr, err := getString(flags, "viewMode") - if err != nil { - return "", err - } +func getViewMode() (users.ViewMode, error) { + viewModeStr := v.GetString("viewmode") viewMode := users.ViewMode(viewModeStr) if viewMode != users.ListViewMode && viewMode != users.MosaicViewMode { return "", errors.New("view mode must be \"" + string(users.ListViewMode) + "\" or \"" + string(users.MosaicViewMode) + "\"") @@ -94,58 +92,55 @@ func getViewMode(flags *pflag.FlagSet) (users.ViewMode, error) { return viewMode, nil } -func getUserDefaults(flags *pflag.FlagSet, defaults *settings.UserDefaults, all bool) error { - var visitErr error - visit := func(flag *pflag.Flag) { - if visitErr != nil { - return +func getUserDefaults(defaults *settings.UserDefaults, all bool) error { + keys := v.AllKeys() + + for _, key := range keys { + if !all && !v.IsSet(key) { + continue } + var err error - switch flag.Name { + switch key { case "scope": - defaults.Scope, err = getString(flags, flag.Name) + defaults.Scope = v.GetString(key) case "locale": - defaults.Locale, err = getString(flags, flag.Name) - case "viewMode": - defaults.ViewMode, err = getViewMode(flags) - case "singleClick": - defaults.SingleClick, err = getBool(flags, flag.Name) - case "aceEditorTheme": - defaults.AceEditorTheme, err = getString(flags, flag.Name) + defaults.Locale = v.GetString(key) + case "viewmode": + defaults.ViewMode, err = getViewMode() + case "singleclick": + defaults.SingleClick = v.GetBool(key) + case "aceeditortheme": + defaults.AceEditorTheme = v.GetString(key) case "perm.admin": - defaults.Perm.Admin, err = getBool(flags, flag.Name) + defaults.Perm.Admin = v.GetBool(key) case "perm.execute": - defaults.Perm.Execute, err = getBool(flags, flag.Name) + defaults.Perm.Execute = v.GetBool(key) case "perm.create": - defaults.Perm.Create, err = getBool(flags, flag.Name) + defaults.Perm.Create = v.GetBool(key) case "perm.rename": - defaults.Perm.Rename, err = getBool(flags, flag.Name) + defaults.Perm.Rename = v.GetBool(key) case "perm.modify": - defaults.Perm.Modify, err = getBool(flags, flag.Name) + defaults.Perm.Modify = v.GetBool(key) case "perm.delete": - defaults.Perm.Delete, err = getBool(flags, flag.Name) + defaults.Perm.Delete = v.GetBool(key) case "perm.share": - defaults.Perm.Share, err = getBool(flags, flag.Name) + defaults.Perm.Share = v.GetBool(key) case "perm.download": - defaults.Perm.Download, err = getBool(flags, flag.Name) + defaults.Perm.Download = v.GetBool(key) case "commands": - defaults.Commands, err = flags.GetStringSlice(flag.Name) + defaults.Commands = v.GetStringSlice(key) case "sorting.by": - defaults.Sorting.By, err = getString(flags, flag.Name) + defaults.Sorting.By = v.GetString(key) case "sorting.asc": - defaults.Sorting.Asc, err = getBool(flags, flag.Name) - case "hideDotfiles": - defaults.HideDotfiles, err = getBool(flags, flag.Name) + defaults.Sorting.Asc = v.GetBool(key) + case "hidedotfiles": + defaults.HideDotfiles = v.GetBool(key) } if err != nil { - visitErr = err + return err } } - if all { - flags.VisitAll(visit) - } else { - flags.Visit(visit) - } - return visitErr + return nil } diff --git a/cmd/users_add.go b/cmd/users_add.go index dce7ff98..8d0e96a6 100644 --- a/cmd/users_add.go +++ b/cmd/users_add.go @@ -2,6 +2,7 @@ package cmd import ( "github.com/spf13/cobra" + v "github.com/spf13/viper" "github.com/filebrowser/filebrowser/v2/users" ) @@ -21,7 +22,7 @@ var usersAddCmd = &cobra.Command{ if err != nil { return err } - err = getUserDefaults(cmd.Flags(), &s.Defaults, false) + err = getUserDefaults(&s.Defaults, false) if err != nil { return err } @@ -31,27 +32,12 @@ var usersAddCmd = &cobra.Command{ return err } - lockPassword, err := getBool(cmd.Flags(), "lockPassword") - if err != nil { - return err - } - - dateFormat, err := getBool(cmd.Flags(), "dateFormat") - if err != nil { - return err - } - - hideDotfiles, err := getBool(cmd.Flags(), "hideDotfiles") - if err != nil { - return err - } - user := &users.User{ Username: args[0], Password: password, - LockPassword: lockPassword, - DateFormat: dateFormat, - HideDotfiles: hideDotfiles, + LockPassword: v.GetBool("lockpassword"), + DateFormat: v.GetBool("dateformat"), + HideDotfiles: v.GetBool("hidedotfiles"), } s.Defaults.Apply(user) diff --git a/cmd/users_import.go b/cmd/users_import.go index 74353c2c..a7ad2cdf 100644 --- a/cmd/users_import.go +++ b/cmd/users_import.go @@ -7,6 +7,7 @@ import ( "strconv" "github.com/spf13/cobra" + v "github.com/spf13/viper" "github.com/filebrowser/filebrowser/v2/users" ) @@ -45,12 +46,7 @@ list or set it to 0.`, } } - replace, err := getBool(cmd.Flags(), "replace") - if err != nil { - return err - } - - if replace { + if v.GetBool("replace") { oldUsers, userImportErr := d.store.Users.Gets("") if userImportErr != nil { return userImportErr @@ -69,10 +65,7 @@ list or set it to 0.`, } } - overwrite, err := getBool(cmd.Flags(), "overwrite") - if err != nil { - return err - } + overwrite := v.GetBool("overwrite") for _, user := range list { onDB, err := d.store.Users.Get("", user.ID) diff --git a/cmd/users_update.go b/cmd/users_update.go index a939e605..0be8b276 100644 --- a/cmd/users_update.go +++ b/cmd/users_update.go @@ -2,6 +2,7 @@ package cmd import ( "github.com/spf13/cobra" + v "github.com/spf13/viper" "github.com/filebrowser/filebrowser/v2/settings" "github.com/filebrowser/filebrowser/v2/users" @@ -23,15 +24,8 @@ options you want to change.`, Args: cobra.ExactArgs(1), RunE: python(func(cmd *cobra.Command, args []string, d *pythonData) error { username, id := parseUsernameOrID(args[0]) - flags := cmd.Flags() - password, err := getString(flags, "password") - if err != nil { - return err - } - newUsername, err := getString(flags, "username") - if err != nil { - return err - } + password := v.GetString("password") + newUsername := v.GetString("username") s, err := d.store.Settings.Get() if err != nil { @@ -61,7 +55,7 @@ options you want to change.`, Sorting: user.Sorting, Commands: user.Commands, } - err = getUserDefaults(flags, &defaults, false) + err = getUserDefaults(&defaults, false) if err != nil { return err } @@ -72,18 +66,9 @@ options you want to change.`, user.Perm = defaults.Perm user.Commands = defaults.Commands user.Sorting = defaults.Sorting - user.LockPassword, err = getBool(flags, "lockPassword") - if err != nil { - return err - } - user.DateFormat, err = getBool(flags, "dateFormat") - if err != nil { - return err - } - user.HideDotfiles, err = getBool(flags, "hideDotfiles") - if err != nil { - return err - } + user.LockPassword = v.GetBool("lockpassword") + user.DateFormat = v.GetBool("dateformat") + user.HideDotfiles = v.GetBool("hidedotfiles") if newUsername != "" { user.Username = newUsername diff --git a/cmd/utils.go b/cmd/utils.go index cc718341..373d6f7a 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -13,7 +13,7 @@ import ( "github.com/asdine/storm/v3" "github.com/spf13/cobra" - "github.com/spf13/pflag" + v "github.com/spf13/viper" yaml "gopkg.in/yaml.v3" "github.com/filebrowser/filebrowser/v2/settings" @@ -23,15 +23,8 @@ import ( const dbPerms = 0640 -func getString(flags *pflag.FlagSet, flag string) (string, error) { - return flags.GetString(flag) -} - -func getMode(flags *pflag.FlagSet, flag string) (fs.FileMode, error) { - s, err := getString(flags, flag) - if err != nil { - return 0, err - } +func getAndParseMode(param string) (fs.FileMode, error) { + s := v.GetString(param) b, err := strconv.ParseUint(s, 0, 32) if err != nil { return 0, err @@ -39,14 +32,6 @@ func getMode(flags *pflag.FlagSet, flag string) (fs.FileMode, error) { return fs.FileMode(b), nil } -func getBool(flags *pflag.FlagSet, flag string) (bool, error) { - return flags.GetBool(flag) -} - -func getUint(flags *pflag.FlagSet, flag string) (uint, error) { - return flags.GetUint(flag) -} - func generateKey() []byte { k, err := settings.GenerateKey() if err != nil { @@ -91,9 +76,14 @@ func dbExists(path string) (bool, error) { func python(fn pythonFunc, cfg pythonConfig) cobraFunc { return func(cmd *cobra.Command, args []string) error { + err := v.BindPFlags(cmd.Flags()) + if err != nil { + panic(err) + } + data := &pythonData{hadDB: true} - path := getStringParam(cmd.Flags(), "database") + path := v.GetString("database") absPath, err := filepath.Abs(path) if err != nil { panic(err)