91 lines
1.9 KiB
Go
91 lines
1.9 KiB
Go
package database
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
_ "github.com/lib/pq"
|
|
"log"
|
|
)
|
|
|
|
// GetUsers returns the username column of every row in the accounts table.
//
// The result is a nil slice when the query yields no rows. A non-nil error is
// returned when the query cannot be run, when a row cannot be scanned, or when
// row iteration stops because of a failure; in each of those cases the
// returned slice is nil.
func GetUsers(db *sql.DB) ([]string, error) {
	const query = `SELECT username FROM accounts;`

	rows, err := db.Query(query)
	if err != nil {
		// Hand the failure back to the caller rather than terminating the
		// whole process from inside a data-access helper.
		return nil, fmt.Errorf("error querying users: %w", err)
	}
	// Release the result set on every return path, including early exits.
	defer rows.Close()

	var users []string
	for rows.Next() {
		var user string
		if err := rows.Scan(&user); err != nil {
			return nil, fmt.Errorf("error scanning user: %w", err)
		}
		users = append(users, user)
	}

	// rows.Next reports false both at end of data and on failure; rows.Err
	// tells the two apart so a truncated list is never returned as a success.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error reading users: %w", err)
	}
	return users, nil
}
|
|
|
|
// CheckUserExists reports whether the accounts table contains at least one
// row whose username equals the given username.
//
// It runs a COUNT query with username bound as a parameter and returns true
// when the count is greater than zero. On failure it returns false together
// with an error that wraps the underlying database error, so callers can
// inspect it with errors.Is / errors.As.
//
// NOTE(review): the SQL compares with a plain `=`, whereas GetPasswordHash in
// this file compares LOWER(username) with LOWER($1). Confirm whether the
// existence check is meant to be case-sensitive.
func CheckUserExists(db *sql.DB, username string) (bool, error) {
	const query = `SELECT COUNT(1) FROM accounts WHERE username= $1`

	var count int
	if err := db.QueryRow(query, username).Scan(&count); err != nil {
		return false, fmt.Errorf("error checking username exists: %w", err)
	}
	return count > 0, nil
}
|
|
|
|
// InsertUser inserts a row into Accounts with the given username and password
// hash and returns the user_id the database reports through the statement's
// RETURNING clause, scanned into a string.
//
// Both values are passed as bound parameters. On success a line naming the
// inserted username is written through the standard logger. On failure the
// returned id is empty and the error wraps the underlying database error, so
// callers can inspect the cause with errors.Is / errors.As.
func InsertUser(db *sql.DB, username string, passwordHash string) (string, error) {
	const query = `
		INSERT INTO Accounts (username, password_hash)
		VALUES ($1, $2)
		RETURNING user_id;
	`

	var userID string
	if err := db.QueryRow(query, username, passwordHash).Scan(&userID); err != nil {
		return "", fmt.Errorf("error inserting user: %w", err)
	}

	log.Printf("Inserted user: %v", username)

	return userID, nil
}
|
|
|
|
// GetPasswordHash returns the password_hash stored in Accounts for the given
// username. The SQL lower-cases both the stored username and the supplied one
// before comparing them.
//
// On failure the returned hash is empty and the error wraps the underlying
// database error. In particular, when no row matches, errors.Is(err,
// sql.ErrNoRows) is true, which lets callers tell an unknown user apart from a
// database failure. The error is only returned, not printed, so it is reported
// once by whichever layer handles it.
func GetPasswordHash(db *sql.DB, username string) (string, error) {
	const query = `
		SELECT password_hash FROM Accounts
		WHERE LOWER(username) = LOWER($1);
	`

	var passwordHash string
	if err := db.QueryRow(query, username).Scan(&passwordHash); err != nil {
		return "", fmt.Errorf("error getting password hash: %w", err)
	}

	return passwordHash, nil
}
|
|
|
|
// GetUserId returns the user_id stored in Accounts for the given username,
// scanned into a string.
//
// The SQL lower-cases both the stored username and the supplied one before
// comparing them, the same way GetPasswordHash does. Comparing
// LOWER(username) against the raw parameter could never match an input that
// contains an uppercase letter.
//
// On failure the returned id is empty and the error wraps the underlying
// database error; when no row matches, errors.Is(err, sql.ErrNoRows) is true.
// The error is only returned, not printed, so it is reported once by
// whichever layer handles it.
func GetUserId(db *sql.DB, username string) (string, error) {
	const query = `
		SELECT user_id FROM Accounts
		WHERE LOWER(username) = LOWER($1);
	`

	var userID string
	if err := db.QueryRow(query, username).Scan(&userID); err != nil {
		return "", fmt.Errorf("error getting user id: %w", err)
	}
	return userID, nil
}
|