From bbe9fd878921a646d711fef5cfcf6f090b88b992 Mon Sep 17 00:00:00 2001 From: wwt Date: Sun, 24 Dec 2023 21:57:48 +0800 Subject: [PATCH] typo --- storage/sql/sql.go | 8 ++++---- storage/sql/users.go | 18 +++++++++++++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/storage/sql/sql.go b/storage/sql/sql.go index aaaf72f4..1273ffdb 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -50,7 +50,7 @@ func getNameQuote(dbType string) string { return "\"" } -// for mysql, it is `` +// for mysql, it is “ // for postgres and sqlite, it is "" func quoteName(dbType string, name string) string { q := getNameQuote(dbType) @@ -64,14 +64,14 @@ func placeHolder(dbType string, index int) string { if index <= 0 { panic("the placeholder index should >= 1") } - if dbType == "postgres" { + if dbType == "postgres" || dbType == "postgresql" { return fmt.Sprintf("$%d", index) } return "?" } func IsDBPath(path string) bool { - prefixes := []string{"sqlite3", "postgres", "mysql"} + prefixes := []string{"sqlite3", "postgres", "postgresql", "mysql"} for _, prefix := range prefixes { if strings.HasPrefix(path, prefix+"://") { return true @@ -84,7 +84,7 @@ func OpenDB(path string) (*sql.DB, error) { if val, ok := dbRecords[path]; ok { return val.db, nil } - prefixes := []string{"sqlite3", "postgres", "mysql"} + prefixes := []string{"sqlite3", "postgres", "postgresql", "mysql"} for _, prefix := range prefixes { if strings.HasPrefix(path, prefix) { db, err := connectDB(prefix, path) diff --git a/storage/sql/users.go b/storage/sql/users.go index 1fd4b56e..fc72ddf5 100644 --- a/storage/sql/users.go +++ b/storage/sql/users.go @@ -117,7 +117,7 @@ func createAdminUser() users.User { func InitUserTable(db *sql.DB, dbType string) error { primaryKey := "integer primary key" - if dbType == "postgres" { + if dbType == "postgres" || dbType == "postgresql" { primaryKey = "serial primary key" } else if dbType == "mysql" { primaryKey = "int unsigned primary key auto_increment" @@ -133,6 +133,18 @@ func newUsersBackend(db *sql.DB, dbType string) usersBackend { return usersBackend{db: db, dbType: dbType} } +func (s usersBackend) IsPostgresql() bool { + return s.dbType == "postgres" || s.dbType == "postgresql" +} + +func (s usersBackend) IsMysql() bool { + return s.dbType == "mysql" +} + +func (s usersBackend) IsSqlite() bool { + return s.dbType == "sqlite3" +} + func (s usersBackend) GetBy(i interface{}) (*users.User, error) { columns := []string{"id", "username", "password", "scope", "locale", "lockpassword", "viewmode", "perm", "commands", "sorting", "rules", "hidedotfiles", "dateformat", "singleclick"} columnsStr := strings.Join(columns, ",") @@ -272,7 +284,7 @@ func (s usersBackend) insertUser(user *users.User) error { columnStr := strings.Join(columns, ",") specStr := strings.Join(specs, ",") sqlFormat := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoteName(s.dbType, UsersTable), columnStr, specStr) - if s.dbType == "postgres" { + if s.IsPostgresql() { sqlFormat = sqlFormat + " RETURNING id;" } sql := fmt.Sprintf( @@ -291,7 +303,7 @@ func (s usersBackend) insertUser(user *users.User) error { boolToString(user.DateFormat), boolToString(user.SingleClick), ) - if s.dbType == "postgres" { + if s.IsPostgresql() { id := uint(0) err := s.db.QueryRow(sql).Scan(&id) if !checkError(err, "Fail to insert user") {