From 79357c8ae2f09641bd125c14964cd96fddef5782 Mon Sep 17 00:00:00 2001 From: slawk0 Date: Thu, 6 Feb 2025 19:38:21 +0100 Subject: [PATCH] implemented new error handler, added getContacts route --- database/auth.go | 75 +++++++++++---------------- database/contacts.go | 111 +++++++++++++++++++++++++++++++++++----- handlers/auth.go | 9 +++- handlers/contacts.go | 84 ++++++++++++++++-------------- helpers/errorHandler.go | 59 +++++++++++++++++++++ main.go | 5 +- router/router.go | 1 + 7 files changed, 243 insertions(+), 101 deletions(-) create mode 100644 helpers/errorHandler.go diff --git a/database/auth.go b/database/auth.go index d7bf66d..79ad7f8 100644 --- a/database/auth.go +++ b/database/auth.go @@ -2,88 +2,71 @@ package database import ( "database/sql" - "fmt" + "errors" "github.com/google/uuid" _ "github.com/lib/pq" - "log" + "relay-server/helpers" ) -func GetUsers(db *sql.DB) ([]string, error) { - query := `SELECT username FROM accounts;` - - rows, err := db.Query(query) - if err != nil { - log.Fatal(err) - } - - var users []string - for rows.Next() { - var user string - if err := rows.Scan(&user); err != nil { - return nil, err - } - users = append(users, user) - } - return users, err -} - func CheckUserExists(db *sql.DB, username string) (bool, error) { - query := `SELECT COUNT(1) FROM accounts WHERE username= $1` + query := `SELECT COUNT(1) FROM accounts WHERE username = $1` - var count int - - err := db.QueryRow(query, username).Scan(&count) - - if err != nil { - return false, fmt.Errorf("error checking username exists: %w", err) + var exist bool + err := db.QueryRow(query, username).Scan(&exist) + if !errors.Is(err, sql.ErrNoRows) { + return false, helpers.NewError(helpers.ErrInternal, "Failed to check username", err) } - return count > 0, err + return exist, nil } func InsertUser(db *sql.DB, username string, passwordHash string) (string, error) { query := ` - INSERT INTO Accounts (username, password_hash) - VALUES ($1, $2) - RETURNING user_id; - ` + INSERT INTO Accounts (username, password_hash) + VALUES ($1, $2) + RETURNING user_id; + ` var userID string err := db.QueryRow(query, username, passwordHash).Scan(&userID) if err != nil { - return "", fmt.Errorf("error inserting user: %w", err) + return "", helpers.NewError(helpers.ErrInternal, "Failed to create user", err) } - fmt.Printf("Inserted user: %w", username) - - return userID, err + return userID, nil } func GetPasswordHash(db *sql.DB, username string) (string, error) { query := ` - SELECT password_hash FROM Accounts - WHERE LOWER(username) = LOWER($1); - ` + SELECT password_hash FROM Accounts + WHERE LOWER(username) = LOWER($1); + ` var passwordHash string err := db.QueryRow(query, username).Scan(&passwordHash) if err != nil { - return "", fmt.Errorf("error getting password hash: %w", err) + if errors.Is(err, sql.ErrNoRows) { + return "", helpers.NewError(helpers.ErrNotFound, "User not found", nil) + } + return "", helpers.NewError(helpers.ErrInternal, "Failed to get user credentials", err) } - return passwordHash, err + return passwordHash, nil } func GetUserID(db *sql.DB, username string) (uuid.UUID, error) { query := ` - SELECT user_id FROM Accounts - WHERE LOWER(username) = LOWER($1); - ` + SELECT user_id FROM Accounts + WHERE LOWER(username) = LOWER($1); + ` var userID uuid.UUID err := db.QueryRow(query, username).Scan(&userID) if err != nil { - return uuid.Nil, fmt.Errorf("error getting user id: %w", err) + if errors.Is(err, sql.ErrNoRows) { + return uuid.Nil, helpers.NewError(helpers.ErrNotFound, "User not found", nil) + } + return uuid.Nil, helpers.NewError(helpers.ErrInternal, "Failed to get user ID", err) } return userID, nil } diff --git a/database/contacts.go b/database/contacts.go index 5b8cc8a..893112d 100644 --- a/database/contacts.go +++ b/database/contacts.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "github.com/google/uuid" + "log" + "relay-server/helpers" "relay-server/model" ) @@ -108,7 +110,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse if errors.Is(err, sql.ErrNoRows) { conversationID = uuid.Nil } - return nil, fmt.Errorf("error finding existing conversation: %w", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("error finding existing conversation: %w", err)) } // if conversation for themselves don't exist @@ -120,7 +122,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse ` err := db.QueryRow(createConversationQuery).Scan(&conversationID) if err != nil { - return nil, fmt.Errorf("error creating conversation for self-contact: %w", err) + return nil, helpers.NewError(helpers.ErrInternal, "Internal server error", fmt.Errorf("error creating conversation for self-contact: %w", err)) } createMembershipQuery := ` @@ -150,7 +152,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse if errors.Is(err, sql.ErrNoRows) { conversationID = uuid.Nil } - return nil, fmt.Errorf("error finding existing conversation: %w", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("error finding existing conversation: %w", err)) } if conversationID != uuid.Nil { // Create a new conversation between user and contact @@ -161,7 +163,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse ` err := db.QueryRow(createConversationQuery).Scan(&conversationID) if err != nil { - return nil, fmt.Errorf("error creating conversation: %w", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("error creating conversation: %w", err)) } createMembershipQuery := ` @@ -172,14 +174,14 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse res, err := db.Exec(createMembershipQuery, conversationID, userID, contactID) if err != nil { - return nil, fmt.Errorf("error creating membership: %w", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("error creating membership: %w", err)) } rowsAffected, err := res.RowsAffected() if err != nil { - return nil, fmt.Errorf("error checking membership creation: %w", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("error checking membership creation: %w", err)) } if rowsAffected == 0 { - return nil, fmt.Errorf("error creating membership %w", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("error creating membership %w", err)) } } } @@ -187,12 +189,12 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse } insertedContact, err := InsertContactByID(db, contactID, conversationID) if err != nil || insertedContact.UserID == uuid.Nil { - return nil, fmt.Errorf("error inserting contact by id: %w", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", err) } latestMessage, err := GetLatestMessage(db, conversationID) if err != nil { - return nil, fmt.Errorf("error getting latest message: %w", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", err) } contact = model.Contact{ @@ -219,9 +221,9 @@ func InsertContactByID(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) ( var contact model.Contact err := db.QueryRow(checkQuery, userID, conversationID).Scan(&contact.ID, &contact.ConversationID, &contact.UserID) if err == nil { - return nil, fmt.Errorf("contact already exists") + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("contact already exists")) } else if !errors.Is(err, sql.ErrNoRows) { - return nil, fmt.Errorf("error checking contact existence: %w", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("error checking contact existence: %w", err)) } insertQuery := ` @@ -232,7 +234,7 @@ func InsertContactByID(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) ( err = db.QueryRow(insertQuery, userID, conversationID).Scan(&contact.ID, &contact.ConversationID, &contact.UserID) if err != nil { - return nil, fmt.Errorf("error inserting contact: %w", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("error inserting contact: %w", err)) } fmt.Printf("Successfully inserted contact by id: %v", conversationID) @@ -257,7 +259,90 @@ func GetLatestMessage(db *sql.DB, conversationId uuid.UUID) (*model.Contact, err err := db.QueryRow(query, conversationId).Scan(&latestMessage.LastMessageID, &latestMessage.LastMessage, &latestMessage.LastMessageTime, &latestMessage.LastMessageSender) if err != nil { - return nil, fmt.Errorf("error getting latest message: %w", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("error getting latest message: %w", err)) } return &latestMessage, nil } + +func GetContacts(db *sql.DB, userID uuid.UUID) ([]*model.Contact, error) { + contactsQuery := ` + WITH DirectContacts AS ( + SELECT DISTINCT ON (c.conversation_id) + c.contact_id AS id, + a.user_id AS user_id, + a.username AS username, + conv.last_active, + c.conversation_id, + conv.conversation_type AS type, + requesting_member.last_read_message_id + FROM Contacts c + JOIN Conversations conv ON c.conversation_id = conv.conversation_id + JOIN Memberships requesting_member + ON requesting_member.conversation_id = conv.conversation_id + AND requesting_member.user_id = $1 + JOIN Memberships other_member + ON other_member.conversation_id = conv.conversation_id + JOIN Accounts a ON a.user_id = other_member.user_id + WHERE c.user_id = $1 + AND conv.conversation_type = 'direct' + AND ( + a.user_id != $1 + OR (SELECT COUNT(*) FROM Memberships WHERE conversation_id = c.conversation_id) = 1 + ) + ), + GroupContacts AS ( + SELECT DISTINCT ON (c.conversation_id) + c.contact_id AS id, + NULL::uuid AS user_id, + conv.name AS username, + conv.last_active, + c.conversation_id, + conv.conversation_type AS type, + m.last_read_message_id + FROM Contacts c + JOIN Conversations conv ON c.conversation_id = conv.conversation_id + JOIN Memberships m + ON m.conversation_id = conv.conversation_id + AND m.user_id = $1 + WHERE c.user_id = $1 + AND conv.conversation_type = 'group' + ) + SELECT * FROM DirectContacts + UNION ALL + SELECT * FROM GroupContacts + ORDER BY last_active DESC NULLS LAST; + ` + + rows, err := db.Query(contactsQuery, userID) + if err != nil { + log.Println("Failed to get contacts:", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("error getting contacts: %w", err)) + } + + var contacts []*model.Contact + for rows.Next() { + contact := &model.Contact{} + err := rows.Scan(&contact.ID, &contact.UserID, &contact.Username, &contact.ConversationID, &contact.Type) + if err != nil { + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("error scanning contact: %w", err)) + } + + latestMessage, err := GetLatestMessage(db, contact.ConversationID) + if err != nil { + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("error getting latest message: %w", err)) + } + + contact.LastMessageID = latestMessage.LastMessageID + contact.LastMessage = latestMessage.LastMessage + contact.LastMessageTime = latestMessage.LastMessageTime + contact.LastMessageSender = latestMessage.LastMessageSender + + contacts = append(contacts, contact) + } + + if err = rows.Err(); err != nil { + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("error iterating over contacts: %w", err)) + } + + return contacts, nil +} diff --git a/handlers/auth.go b/handlers/auth.go index c4df198..9e21a43 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -1,6 +1,7 @@ package handlers import ( + "errors" "github.com/gofiber/fiber/v2" "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" @@ -127,8 +128,12 @@ func Login(c *fiber.Ctx) error { userID, err := database.GetUserID(db, u.Username) if err != nil { - log.Print(err) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "internal server error"}) + var e *helpers.Error + if errors.As(err, &e) { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "user does not exist"}) + } + log.Println(err) + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Internal server error"}) } // Generate token with user id and username token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ diff --git a/handlers/contacts.go b/handlers/contacts.go index f280155..5127270 100644 --- a/handlers/contacts.go +++ b/handlers/contacts.go @@ -1,18 +1,13 @@ package handlers import ( - "database/sql" - "errors" - "fmt" "github.com/gofiber/fiber/v2" "github.com/google/uuid" - "log" "relay-server/database" "relay-server/helpers" ) func DeleteContact(c *fiber.Ctx) error { - type params struct { ContactID uuid.UUID `params:"contact_id"` ConversationID uuid.UUID `params:"conversation_id"` @@ -20,74 +15,85 @@ func DeleteContact(c *fiber.Ctx) error { p := new(params) if err := c.ParamsParser(p); err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid params"}) + return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err) } - db := database.DB - if p.ContactID == uuid.Nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "contact_id is empty"}) + return helpers.NewError(helpers.ErrInvalidInput, "contact ID is empty", nil) } if p.ConversationID == uuid.Nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "conversation_id is empty"}) + return helpers.NewError(helpers.ErrInvalidInput, "conversation ID is empty", nil) } - msg, err := database.DeleteContact(db, p.ContactID, p.ConversationID) + msg, err := database.DeleteContact(database.DB, p.ContactID, p.ConversationID) if err != nil { - log.Println(err) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) + return helpers.NewError(helpers.ErrInternal, "Failed to delete contact", err) } if msg != "" { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": msg}) + return helpers.NewError(helpers.ErrInvalidInput, msg, nil) } - fmt.Println("Contact deleted") return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Contact deleted"}) } func InsertContact(c *fiber.Ctx) error { - userID, err := uuid.Parse(c.Locals("userID").(string)) - if err != nil { - log.Println(err) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Internal server error"}) - } type params struct { ContactUsername string `params:"contact_username"` } - p := new(params) - - if err := c.ParamsParser(p); err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid params"}) + userIDVal := c.Locals("userID") + userIDStr, ok := userIDVal.(string) + if !ok { + return helpers.NewError(helpers.ErrInternal, "Internal server error", nil) } - db := database.DB + userID, err := uuid.Parse(userIDStr) + if err != nil { + return helpers.NewError(helpers.ErrInvalidInput, "Invalid user ID format", err) + } + + p := new(params) + if err := c.ParamsParser(p); err != nil { + return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err) + } if p.ContactUsername == "" { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "contact_username is empty"}) + return helpers.NewError(helpers.ErrInvalidInput, "contact username is empty", nil) } if !helpers.IsValidUsername(p.ContactUsername) { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "username is invalid"}) + return helpers.NewError(helpers.ErrInvalidInput, "username is invalid", nil) } - contactID, err := database.GetUserID(db, p.ContactUsername) - + contactID, err := database.GetUserID(database.DB, p.ContactUsername) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "user does not exist"}) - } - log.Println(err) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Internal server error"}) + return err } - newContacts, err := database.InsertContact(db, userID, contactID, p.ContactUsername) + newContacts, err := database.InsertContact(database.DB, userID, contactID, p.ContactUsername) if err != nil { - log.Println(err) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Internal server error"}) + return err } - log.Println("Contact added") return c.Status(fiber.StatusOK).JSON(newContacts) - +} + +func GetContacts(c *fiber.Ctx) error { + userIDVal := c.Locals("userID") + userIDStr, ok := userIDVal.(string) + if !ok { + return helpers.NewError(helpers.ErrInternal, "Invalid user session", nil) + } + + userID, err := uuid.Parse(userIDStr) + if err != nil { + return helpers.NewError(helpers.ErrInvalidInput, "Invalid user ID format", err) + } + + contacts, err := database.GetContacts(database.DB, userID) + if err != nil { + return err + } + + return c.Status(fiber.StatusOK).JSON(contacts) } diff --git a/helpers/errorHandler.go b/helpers/errorHandler.go new file mode 100644 index 0000000..e8a5a69 --- /dev/null +++ b/helpers/errorHandler.go @@ -0,0 +1,59 @@ +package helpers + +import ( + "errors" + "github.com/gofiber/fiber/v2" +) + +type ErrorCode string + +const ( + ErrInternal ErrorCode = "INTERNAL_ERROR" + ErrNotFound ErrorCode = "NOT_FOUND" + ErrInvalidInput ErrorCode = "INVALID_INPUT" +) + +type Error struct { + Code ErrorCode + UserMessage string + InternalError error + StatusCode int +} + +func (e *Error) Error() string { + if e.InternalError != nil { + return e.InternalError.Error() + } + return e.UserMessage +} + +func (e *Error) Unwrap() error { + return e.InternalError +} + +func NewError(code ErrorCode, userMsg string, internalErr error) *Error { + statusCode := fiber.StatusInternalServerError + switch code { + case ErrNotFound: + statusCode = fiber.StatusNotFound + case ErrInvalidInput: + statusCode = fiber.StatusBadRequest + case ErrInternal: + statusCode = fiber.StatusInternalServerError + } + + return &Error{ + Code: code, + UserMessage: userMsg, + InternalError: internalErr, + StatusCode: statusCode, + } +} + +func ErrorHandler(c *fiber.Ctx, err error) error { + var e *Error + if errors.As(err, &e) { + return c.Status(e.StatusCode).JSON(fiber.Map{"message": e.UserMessage}) + } + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "An unexpected error occurred"}) +} diff --git a/main.go b/main.go index f4fd297..31086ba 100644 --- a/main.go +++ b/main.go @@ -4,11 +4,14 @@ import ( "github.com/gofiber/fiber/v2" "log" "relay-server/database" + "relay-server/helpers" "relay-server/router" ) func main() { - app := fiber.New() + app := fiber.New(fiber.Config{ + ErrorHandler: helpers.ErrorHandler, + }) db, err := database.Init() if err != nil { log.Fatal("Failed to initialize database") diff --git a/router/router.go b/router/router.go index 35a931a..33c3487 100644 --- a/router/router.go +++ b/router/router.go @@ -24,4 +24,5 @@ func SetupRoutes(app *fiber.App) { contacts := chat.Group("/contacts", middleware.Protected(), logger.New()) contacts.Delete("/:contact_id/:conversation_id", handlers.DeleteContact) contacts.Post("/:contact_username", handlers.InsertContact) + contacts.Get("/", handlers.GetContacts) }