diff --git a/database/auth.go b/database/auth.go index cc23e2f..01e7c64 100644 --- a/database/auth.go +++ b/database/auth.go @@ -5,7 +5,7 @@ import ( "errors" "github.com/google/uuid" _ "github.com/lib/pq" - "relay-server/helpers" + "relay-server/utils" ) func CheckUserExists(db *sql.DB, username string) (bool, error) { @@ -14,7 +14,7 @@ func CheckUserExists(db *sql.DB, username string) (bool, error) { var count int err := db.QueryRow(query, username).Scan(&count) if err != nil { - return false, helpers.NewError(helpers.ErrInternal, "Failed to check user existence", err) + return false, utils.NewError(utils.ErrInternal, "Failed to check user existence", err) } return count > 0, nil @@ -30,7 +30,7 @@ func InsertUser(db *sql.DB, username string, passwordHash string) (string, error var userID string err := db.QueryRow(query, username, passwordHash).Scan(&userID) if err != nil { - return "", helpers.NewError(helpers.ErrInternal, "Failed to create user", err) + return "", utils.NewError(utils.ErrInternal, "Failed to create user", err) } return userID, nil @@ -46,9 +46,9 @@ func GetPasswordHash(db *sql.DB, username string) (string, error) { err := db.QueryRow(query, username).Scan(&passwordHash) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return "", helpers.NewError(helpers.ErrNotFound, "User not found", nil) + return "", utils.NewError(utils.ErrNotFound, "User not found", nil) } - return "", helpers.NewError(helpers.ErrInternal, "Failed to get user credentials", err) + return "", utils.NewError(utils.ErrInternal, "Failed to get user credentials", err) } return passwordHash, nil @@ -64,9 +64,9 @@ func GetUserID(db *sql.DB, username string) (uuid.UUID, error) { err := db.QueryRow(query, username).Scan(&userID) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return uuid.Nil, helpers.NewError(helpers.ErrNotFound, "User not found", nil) + return uuid.Nil, utils.NewError(utils.ErrNotFound, "User not found", nil) } - return uuid.Nil, helpers.NewError(helpers.ErrInternal, "Failed to get user ID", err) + return uuid.Nil, utils.NewError(utils.ErrInternal, "Failed to get user ID", err) } return userID, nil } diff --git a/database/contacts.go b/database/contacts.go index 44f6aee..118d286 100644 --- a/database/contacts.go +++ b/database/contacts.go @@ -5,8 +5,8 @@ import ( "errors" "fmt" "github.com/google/uuid" - "relay-server/helpers" "relay-server/model" + "relay-server/utils" "strings" ) @@ -18,9 +18,9 @@ func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) error ).Scan(&conversationType) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return helpers.NewError(helpers.ErrNotFound, "no contacts found for this id", fmt.Errorf("no conversation found with id: %s", conversationID)) + return utils.NewError(utils.ErrNotFound, "no contacts found for this id", fmt.Errorf("no conversation found with id: %s", conversationID)) } - return helpers.NewError(helpers.ErrInternal, "Failed to check conversation", err) + return utils.NewError(utils.ErrInternal, "Failed to check conversation", err) } if conversationType == "group" { @@ -30,15 +30,15 @@ func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) error conversationID, userID, ) if err != nil { - return helpers.NewError(helpers.ErrInternal, "Failed to delete contact", err) + return utils.NewError(utils.ErrInternal, "Failed to delete contact", err) } rowsAffected, err := res.RowsAffected() if err != nil { - return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to verify contact deletion: %w", err)) + return utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to verify contact deletion: %w", err)) } if rowsAffected == 0 { - return helpers.NewError(helpers.ErrNotFound, fmt.Sprintf("no matching contact found with conversation id: %s, user id: %s", conversationID, userID), nil) + return utils.NewError(utils.ErrNotFound, fmt.Sprintf("no matching contact found with conversation id: %s, user id: %s", conversationID, userID), nil) } // Delete from Memberships @@ -47,15 +47,15 @@ func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) error conversationID, userID, ) if err != nil { - return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to delete membership: %w", err)) + return utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to delete membership: %w", err)) } rowsAffected, err = res.RowsAffected() if err != nil { - return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to verify membership deletion: %w", err)) + return utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to verify membership deletion: %w", err)) } if rowsAffected == 0 { - return helpers.NewError(helpers.ErrNotFound, "No membership found", err) + return utils.NewError(utils.ErrNotFound, "No membership found", err) } } else { res, err := db.Exec( @@ -63,15 +63,15 @@ func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) error userID, conversationID, ) if err != nil { - return helpers.NewError(helpers.ErrInternal, "Failed to delete contact", err) + return utils.NewError(utils.ErrInternal, "Failed to delete contact", err) } rowsAffected, err := res.RowsAffected() if err != nil { - return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to verify contact deletion: %w", err)) + return utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to verify contact deletion: %w", err)) } if rowsAffected == 0 { - return helpers.NewError(helpers.ErrNotFound, fmt.Sprintf("no matching contact found with user id: %s, conversation id: %s", userID, conversationID), nil) + return utils.NewError(utils.ErrNotFound, fmt.Sprintf("no matching contact found with user id: %s, conversation id: %s", userID, conversationID), nil) } } @@ -94,7 +94,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse `, userID).Scan(&conversationID) if err != nil && !errors.Is(err, sql.ErrNoRows) { - return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to check existing conversation: %w", err)) + return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to check existing conversation: %w", err)) } if conversationID == uuid.Nil { @@ -104,7 +104,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse RETURNING conversation_id; `).Scan(&conversationID) if err != nil { - return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to create conversation: %w", err)) + return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to create conversation: %w", err)) } _, err = db.Exec(` @@ -113,7 +113,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse ON CONFLICT (conversation_id, user_id) DO NOTHING; `, conversationID, userID) if err != nil { - return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to create membership: %w", err)) + return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to create membership: %w", err)) } } } else { @@ -128,7 +128,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse `, userID, contactID).Scan(&conversationID) if err != nil && !errors.Is(err, sql.ErrNoRows) { - return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to check existing conversation %w", err)) + return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to check existing conversation %w", err)) } if conversationID == uuid.Nil { @@ -138,7 +138,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse RETURNING conversation_id; `).Scan(&conversationID) if err != nil { - return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to create conversation: %w", err)) + return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to create conversation: %w", err)) } _, err = db.Exec(` @@ -147,7 +147,7 @@ func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUse ON CONFLICT (conversation_id, user_id) DO NOTHING; `, conversationID, userID, contactID) if err != nil { - return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to create memberships: %w", err)) + return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to create memberships: %w", err)) } } } @@ -186,9 +186,9 @@ func InsertContactByID(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) ( `, userID, conversationID).Scan(&contact.ID, &contact.ConversationID, &contact.UserID) if err == nil { - return &model.Contact{}, helpers.NewError(helpers.ErrInvalidInput, "Contact already exists", nil) + return &model.Contact{}, utils.NewError(utils.ErrInvalidInput, "Contact already exists", nil) } else if !errors.Is(err, sql.ErrNoRows) { - return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to check contact existence: %w", err)) + return &model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to check contact existence: %w", err)) } // Insert new contact @@ -199,7 +199,7 @@ func InsertContactByID(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) ( `, userID, conversationID).Scan(&contact.ID, &contact.ConversationID, &contact.UserID) if err != nil { - return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "Failed to create contact", err) + return &model.Contact{}, utils.NewError(utils.ErrInternal, "Failed to create contact", err) } return &contact, nil @@ -230,7 +230,7 @@ func GetLatestMessage(db *sql.DB, conversationId uuid.UUID) (*model.Contact, err if errors.Is(err, sql.ErrNoRows) { return &model.Contact{}, nil } - return &model.Contact{}, helpers.NewError(helpers.ErrInternal, "Failed to get latest message", fmt.Errorf("failed to get latest message: %w", err)) + return &model.Contact{}, utils.NewError(utils.ErrInternal, "Failed to get latest message", fmt.Errorf("failed to get latest message: %w", err)) } return &latestMessage, nil @@ -287,7 +287,7 @@ func GetContacts(db *sql.DB, userID uuid.UUID) ([]*model.Contact, error) { rows, err := db.Query(contactsQuery, userID) if err != nil { - return []*model.Contact{}, helpers.NewError(helpers.ErrInternal, "Failed to get contacts", fmt.Errorf("failed to get contacts: %w", err)) + return []*model.Contact{}, utils.NewError(utils.ErrInternal, "Failed to get contacts", fmt.Errorf("failed to get contacts: %w", err)) } defer rows.Close() @@ -296,7 +296,7 @@ func GetContacts(db *sql.DB, userID uuid.UUID) ([]*model.Contact, error) { contact := &model.Contact{} err := rows.Scan(&contact.ID, &contact.UserID, &contact.Username, &contact.LastActive, &contact.ConversationID, &contact.Type, &contact.LastReadMessageID) if err != nil { - return []*model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to scan contact: %w", err)) + return []*model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to scan contact: %w", err)) } latestMessage, err := GetLatestMessage(db, contact.ConversationID) @@ -313,7 +313,7 @@ func GetContacts(db *sql.DB, userID uuid.UUID) ([]*model.Contact, error) { } if err = rows.Err(); err != nil { - return []*model.Contact{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process contacts: %w", err)) + return []*model.Contact{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to process contacts: %w", err)) } return contacts, nil @@ -329,7 +329,7 @@ func ContactSuggestion(db *sql.DB, contactUsername string) ([]string, error) { rows, err := db.Query(query, "%"+strings.ToLower(contactUsername)+"%") if err != nil && !errors.Is(err, sql.ErrNoRows) { - return []string{}, helpers.NewError(helpers.ErrInternal, "Failed to get contact suggestions", fmt.Errorf("failed to get contact suggestions: %w", err)) + return []string{}, utils.NewError(utils.ErrInternal, "Failed to get contact suggestions", fmt.Errorf("failed to get contact suggestions: %w", err)) } defer rows.Close() @@ -338,12 +338,12 @@ func ContactSuggestion(db *sql.DB, contactUsername string) ([]string, error) { var suggestion string err := rows.Scan(&suggestion) if err != nil { - return []string{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to scan contact suggestion: %w", err)) + return []string{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to scan contact suggestion: %w", err)) } suggestions = append(suggestions, suggestion) } if err = rows.Err(); err != nil { - return []string{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process suggestions: %w", err)) + return []string{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to process suggestions: %w", err)) } return suggestions, nil } diff --git a/database/groups.go b/database/groups.go index c281f39..d0cc2dc 100644 --- a/database/groups.go +++ b/database/groups.go @@ -4,8 +4,8 @@ import ( "database/sql" "fmt" "github.com/google/uuid" - "relay-server/helpers" "relay-server/model" + "relay-server/utils" ) func CreateGroup(db *sql.DB, groupName string, userID uuid.UUID) (uuid.UUID, error) { @@ -25,14 +25,14 @@ func CreateGroup(db *sql.DB, groupName string, userID uuid.UUID) (uuid.UUID, err var groupID uuid.UUID err := db.QueryRow(createConversationQuery, groupName).Scan(&groupID) if err != nil { - return uuid.Nil, helpers.NewError(helpers.ErrInternal, "Failed to create group", fmt.Errorf("failed to create group: %w", err)) + return uuid.Nil, utils.NewError(utils.ErrInternal, "Failed to create group", fmt.Errorf("failed to create group: %w", err)) } // Insert group admin (make user that created the group an admin) var grantedAt string err = db.QueryRow(insertGroupAdminQuery, groupID, userID, userID).Scan(&grantedAt) if err != nil { - return uuid.Nil, helpers.NewError(helpers.ErrInternal, "Failed to create group", fmt.Errorf("failed to insert group admin: %w", err)) + return uuid.Nil, utils.NewError(utils.ErrInternal, "Failed to create group", fmt.Errorf("failed to insert group admin: %w", err)) } // Add self as group member @@ -44,7 +44,7 @@ func CreateGroup(db *sql.DB, groupName string, userID uuid.UUID) (uuid.UUID, err // Insert group contact _, err = InsertContactByID(db, userID, groupID) if err != nil { - return groupID, helpers.NewError(helpers.ErrInternal, "Failed to create group contact", fmt.Errorf("failed to insert group contact: %w", err)) + return groupID, utils.NewError(utils.ErrInternal, "Failed to create group contact", fmt.Errorf("failed to insert group contact: %w", err)) } return groupID, nil @@ -60,7 +60,7 @@ func AddMemberToGroup(db *sql.DB, userID uuid.UUID, groupID uuid.UUID) (uuid.UUI var memberID uuid.UUID err := db.QueryRow(query, groupID, userID).Scan(&memberID) if err != nil { - return uuid.Nil, helpers.NewError(helpers.ErrInternal, "Failed to add member to group", fmt.Errorf("failed to add member to group: %w", err)) + return uuid.Nil, utils.NewError(utils.ErrInternal, "Failed to add member to group", fmt.Errorf("failed to add member to group: %w", err)) } return memberID, nil } @@ -83,7 +83,7 @@ func GetMembers(db *sql.DB, groupID uuid.UUID) ([]*model.Member, error) { rows, err := db.Query(query, groupID) if err != nil { - return []*model.Member{}, helpers.NewError(helpers.ErrInternal, "Failed to get members", fmt.Errorf("failed to get members: %w", err)) + return []*model.Member{}, utils.NewError(utils.ErrInternal, "Failed to get members", fmt.Errorf("failed to get members: %w", err)) } defer rows.Close() @@ -92,13 +92,13 @@ func GetMembers(db *sql.DB, groupID uuid.UUID) ([]*model.Member, error) { var member model.Member err = rows.Scan(&member.UserID, &member.Username, &member.IsAdmin, &member.IsOwner) if err != nil { - return []*model.Member{}, helpers.NewError(helpers.ErrInternal, "Failed to get members", fmt.Errorf("failed to scan member: %w", err)) + return []*model.Member{}, utils.NewError(utils.ErrInternal, "Failed to get members", fmt.Errorf("failed to scan member: %w", err)) } members = append(members, &member) } if err = rows.Err(); err != nil { - return []*model.Member{}, helpers.NewError(helpers.ErrInternal, "Failed to get members", fmt.Errorf("error iterating members: %w", err)) + return []*model.Member{}, utils.NewError(utils.ErrInternal, "Failed to get members", fmt.Errorf("error iterating members: %w", err)) } return members, nil @@ -113,7 +113,7 @@ func IsAdmin(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (bool, erro var count int err := db.QueryRow(query, userID, conversationID).Scan(&count) if err != nil { - return false, helpers.NewError(helpers.ErrInternal, "Failed to check admin status", fmt.Errorf("failed to check admin status: %w", err)) + return false, utils.NewError(utils.ErrInternal, "Failed to check admin status", fmt.Errorf("failed to check admin status: %w", err)) } return count > 0, nil } @@ -127,7 +127,83 @@ func IsMember(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (bool, err var count int err := db.QueryRow(query, userID, conversationID).Scan(&count) if err != nil { - return false, helpers.NewError(helpers.ErrInternal, "Failed to check membership", fmt.Errorf("failed to check membership: %w", err)) + return false, utils.NewError(utils.ErrInternal, "Failed to check membership", fmt.Errorf("failed to check membership: %w", err)) } return count > 0, nil } + +func GetUserConversations(db *sql.DB, userID string) ([]string, error) { + query := ` + SELECT DISTINCT m.conversation_id + FROM Memberships m + JOIN Conversations c ON m.conversation_id = c.conversation_id + WHERE m.user_id = $1 + AND c.conversation_type = 'group'; + ` + + var conversations []string + + rows, err := db.Query(query, userID) + if err != nil { + return []string{}, fmt.Errorf("failed to get user conversations: %w", err) + } + defer rows.Close() + + for rows.Next() { + var conversationID string + err = rows.Scan(&conversationID) + if err != nil { + return []string{}, fmt.Errorf("failed to scan conversation: %w", err) + } + conversations = append(conversations, conversationID) + } + if err := rows.Err(); err != nil { + return []string{}, fmt.Errorf("error iterating conversations: %w", err) + } + + return conversations, nil +} + +func RemoveUserFromGroup(db *sql.DB, userID uuid.UUID, groupID uuid.UUID) (string, error) { + removeUserFromGroupQuery := ` + DELETE FROM Memberships + WHERE conversation_id = $1 AND user_id = $2; + ` + + isOwner, err := IsGroupOwner(db, userID, groupID) + if err != nil { + return "Failed to remove user from group", fmt.Errorf("failed to check if user is group owner: %w", err) + } + if !isOwner { + return "Cannot remove group owner", nil + } + res, err := db.Exec(removeUserFromGroupQuery, groupID, userID) + if err != nil { + return "Failed to remove user from group", fmt.Errorf("failed to remove user from group: %w", err) + } + rowsAffected, err := res.RowsAffected() + if err != nil { + return "Failed to remove user from group", fmt.Errorf("failed to get rows affected: %w", err) + } + if rowsAffected == 0 { + return "User is not a member of the group", nil + } + return "Successfully removed user from group", nil +} + +func IsGroupOwner(db *sql.DB, userID uuid.UUID, groupID uuid.UUID) (bool, error) { + query := ` + SELECT EXISTS ( + SELECT 1 FROM GroupAdmins + WHERE conversation_id = $1 + AND user_id = $2 + AND is_owner = true + ) AS is_owner; + ` + var isOwner bool + err := db.QueryRow(query, groupID, userID).Scan(&isOwner) + if err != nil { + return false, fmt.Errorf("failed to check if user is group owner: %w", err) + } + return isOwner, nil +} diff --git a/database/messages.go b/database/messages.go index ddcd274..242ea2d 100644 --- a/database/messages.go +++ b/database/messages.go @@ -4,8 +4,8 @@ import ( "database/sql" "fmt" "github.com/google/uuid" - "relay-server/helpers" "relay-server/model" + "relay-server/utils" ) func GetMessages(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID, limit int, cursor int) ([]*model.Message, error) { @@ -33,20 +33,20 @@ func GetMessages(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID, limit i ` rows, err := db.Query(query, conversationID, cursor, limit) if err != nil { - return []*model.Message{}, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to get messages: %w", err)) + return []*model.Message{}, utils.NewError(utils.ErrInternal, "Failed to get messages", fmt.Errorf("failed to get messages: %w", err)) } defer rows.Close() for rows.Next() { message := &model.Message{} - err = rows.Scan(&message.MessageID, &message.Message, &message.SentAt, &message.AttachmentUrl, &message.Sender) + err = rows.Scan(&message.ID, &message.Message, &message.SentAt, &message.AttachmentUrl, &message.Sender) if err != nil { - return []*model.Message{}, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to scan message: %w", err)) + return []*model.Message{}, utils.NewError(utils.ErrInternal, "Failed to get messages", fmt.Errorf("failed to scan message: %w", err)) } messages = append(messages, message) } if err = rows.Err(); err != nil { - return []*model.Message{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process messages: %w", err)) + return []*model.Message{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to process messages: %w", err)) } } else { @@ -65,20 +65,20 @@ func GetMessages(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID, limit i ` rows, err := db.Query(query, conversationID, cursor, limit) if err != nil { - return []*model.Message{}, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to get messages: %w", err)) + return []*model.Message{}, utils.NewError(utils.ErrInternal, "Failed to get messages", fmt.Errorf("failed to get messages: %w", err)) } defer rows.Close() for rows.Next() { message := &model.Message{} - err = rows.Scan(&message.MessageID, &message.Message, &message.SentAt, &message.AttachmentUrl, &message.Sender) + err = rows.Scan(&message.ID, &message.Message, &message.SentAt, &message.AttachmentUrl, &message.Sender) if err != nil { - return []*model.Message{}, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to scan message: %w", err)) + return []*model.Message{}, utils.NewError(utils.ErrInternal, "Failed to get messages", fmt.Errorf("failed to scan message: %w", err)) } messages = append(messages, message) } if err = rows.Err(); err != nil { - return []*model.Message{}, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process messages: %w", err)) + return []*model.Message{}, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to process messages: %w", err)) } } if cursor != 0 { @@ -102,10 +102,10 @@ func checkMembership(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (bo var isMember bool err := db.QueryRow(query, userID, conversationID).Scan(&isMember) if err != nil { - return false, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to check membership: %w", err)) + return false, utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to check membership: %w", err)) } if !isMember { - return false, helpers.NewError(helpers.ErrForbidden, "You are member of the conversation", nil) + return false, utils.NewError(utils.ErrForbidden, "You are member of the conversation", nil) } return isMember, nil } @@ -121,7 +121,7 @@ func DeleteMessage(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID, messa var messageOwnerID uuid.UUID err := db.QueryRow(checkMessageOwnershipQuery, messageID).Scan(&messageOwnerID) if err != nil { - return helpers.NewError(helpers.ErrInternal, "Failed to delete message", fmt.Errorf("failed to check message ownership: %w", err)) + return utils.NewError(utils.ErrInternal, "Failed to delete message", fmt.Errorf("failed to check message ownership: %w", err)) } var isSelfMessage bool @@ -134,20 +134,51 @@ func DeleteMessage(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID, messa return err } if !isSelfMessage && !isAdmin { - return helpers.NewError(helpers.ErrForbidden, "You don't have permissions to delete that message ", nil) + return utils.NewError(utils.ErrForbidden, "You don't have permissions to delete that message ", nil) } row, err := db.Exec(deleteMessageQuery, messageID) if err != nil { - return helpers.NewError(helpers.ErrInternal, "Failed to delete message", fmt.Errorf("failed to delete message: %w", err)) + return utils.NewError(utils.ErrInternal, "Failed to delete message", fmt.Errorf("failed to delete message: %w", err)) } rowsAffected, err := row.RowsAffected() if err != nil { - return helpers.NewError(helpers.ErrInternal, "Failed to delete message", fmt.Errorf("failed to get rows affected: %w", err)) + return utils.NewError(utils.ErrInternal, "Failed to delete message", fmt.Errorf("failed to get rows affected: %w", err)) } if rowsAffected == 0 { - return helpers.NewError(helpers.ErrNotFound, "Message not found", nil) + return utils.NewError(utils.ErrNotFound, "Message not found", nil) } return nil } + +func InsertMessage(db *sql.DB, senderID uuid.UUID, conversationID uuid.UUID, message string, attachmentUrls []string) (*model.Message, error) { + isMember, err := checkMembership(db, senderID, conversationID) + if err != nil { + return &model.Message{}, fmt.Errorf("failed to check membership: %w", err) + } + if !isMember { + return &model.Message{}, fmt.Errorf("user is not a member of the conversation") + } + + query := ` + INSERT INTO Messages (conversation_id, user_id, content, attachment_urls) + VALUES ($1, $2, $3, $4) + RETURNING message_id, content AS message, sent_at, attachment_urls, user_id AS sender_id, conversation_id; + ` + + var msg model.Message + + err = db.QueryRow(query, conversationID, senderID, message, attachmentUrls).Scan( + &msg.ID, + &msg.Message, + &msg.SentAt, + &msg.AttachmentUrl, + &msg.SenderID, + &msg.ConversationID, + ) + if err != nil { + return &model.Message{}, fmt.Errorf("failed to insert message: %w", err) + } + return &msg, nil +} diff --git a/go.mod b/go.mod index c6cdae7..8aadad8 100644 --- a/go.mod +++ b/go.mod @@ -12,16 +12,25 @@ require ( require ( github.com/MicahParks/keyfunc/v2 v2.1.0 // indirect github.com/andybalholm/brotli v1.1.1 // indirect + github.com/fasthttp/websocket v1.5.12 // indirect github.com/gofiber/contrib/jwt v1.0.10 // indirect + github.com/gofiber/contrib/socketio v1.1.4 // indirect + github.com/gofiber/contrib/websocket v1.3.3 // indirect + github.com/gofrs/uuid v4.0.0+incompatible // indirect github.com/golang-jwt/jwt/v5 v5.2.1 // indirect + github.com/gomodule/redigo v1.8.4 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/googollee/go-socket.io v1.7.0 // indirect + github.com/gorilla/websocket v1.4.2 // indirect github.com/klauspost/compress v1.17.11 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.58.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect - golang.org/x/sys v0.29.0 // indirect + golang.org/x/net v0.34.0 // indirect + golang.org/x/sys v0.30.0 // indirect ) diff --git a/go.sum b/go.sum index a832a39..dd7ff18 100644 --- a/go.sum +++ b/go.sum @@ -4,14 +4,30 @@ github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1 github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fasthttp/websocket v1.5.12 h1:e4RGPpWW2HTbL3zV0Y/t7g0ub294LkiuXXUuTOUInlE= +github.com/fasthttp/websocket v1.5.12/go.mod h1:I+liyL7/4moHojiOgUOIKEWm9EIxHqxZChS+aMFltyg= github.com/gofiber/contrib/jwt v1.0.10 h1:/ilGepl6i0Bntl0Zcd+lAzagY8BiS1+fEiAj32HMApk= github.com/gofiber/contrib/jwt v1.0.10/go.mod h1:1qBENE6sZ6PPT4xIpBzx1VxeyROQO7sj48OlM1I9qdU= +github.com/gofiber/contrib/socketio v1.1.4 h1:XoS4N4yvbVJeFOfzFOiHKRGn++Vax+doQhJLZEMnc5M= +github.com/gofiber/contrib/socketio v1.1.4/go.mod h1:ZqUgo7SYEp7TJFH+BwpW2F0Xo4BdIYLLPkbxxnK3KnY= +github.com/gofiber/contrib/websocket v1.3.3 h1:R6DlDKieGPMiDrqYNyobsHbvjqvxMHeCj/lLaca4jg8= +github.com/gofiber/contrib/websocket v1.3.3/go.mod h1:07u6QGMsvX+sx7iGNCl5xhzuUVArWwLQ3tBIH24i+S8= github.com/gofiber/fiber/v2 v2.52.6 h1:Rfp+ILPiYSvvVuIPvxrBns+HJp8qGLDnLJawAu27XVI= github.com/gofiber/fiber/v2 v2.52.6/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/gomodule/redigo v1.8.4 h1:Z5JUg94HMTR1XpwBaSH4vq3+PNSIykBLxMdglbw10gg= +github.com/gomodule/redigo v1.8.4/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googollee/go-socket.io v1.7.0 h1:ODcQSAvVIPvKozXtUGuJDV3pLwdpBLDs1Uoq/QHIlY8= +github.com/googollee/go-socket.io v1.7.0/go.mod h1:0vGP8/dXR9SZUMMD4+xxaGo/lohOw3YWMh2WRiWeKxg= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= @@ -29,10 +45,15 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc= +github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= @@ -46,9 +67,16 @@ golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/handlers/auth.go b/handlers/auth.go index ec4b580..b2ff0cd 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -9,7 +9,7 @@ import ( "os" "relay-server/config" "relay-server/database" - "relay-server/helpers" + "relay-server/utils" "time" ) @@ -21,21 +21,21 @@ func Signup(c *fiber.Ctx) error { u := new(SignupStruct) if err := c.BodyParser(u); err != nil { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid request body", err) + return utils.NewError(utils.ErrInvalidInput, "Invalid request body", err) } // Validate input if u.Username == "" { - return helpers.NewError(helpers.ErrInvalidInput, "Username is empty", nil) + return utils.NewError(utils.ErrInvalidInput, "Username is empty", nil) } if u.Password == "" { - return helpers.NewError(helpers.ErrInvalidInput, "Password is empty", nil) + return utils.NewError(utils.ErrInvalidInput, "Password is empty", nil) } - if !helpers.IsValidPassword(u.Password) { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid password", nil) + if !utils.IsValidPassword(u.Password) { + return utils.NewError(utils.ErrInvalidInput, "Invalid password", nil) } - if !helpers.IsValidUsername(u.Username) { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid username", nil) + if !utils.IsValidUsername(u.Username) { + return utils.NewError(utils.ErrInvalidInput, "Invalid username", nil) } // Check if user exists @@ -44,13 +44,13 @@ func Signup(c *fiber.Ctx) error { return err } if exist { - return helpers.NewError(helpers.ErrInvalidInput, "User already exists", nil) + return utils.NewError(utils.ErrInvalidInput, "User already exists", nil) } // Create password hash passwordHash, err := bcrypt.GenerateFromPassword([]byte(u.Password), config.BCRYPT_COST) if err != nil { - return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to generate password hash: %w", err)) + return utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to generate password hash: %w", err)) } // Insert user @@ -66,7 +66,7 @@ func Signup(c *fiber.Ctx) error { }) signedToken, err := token.SignedString([]byte(os.Getenv("JWT_SECRET"))) if err != nil { - return helpers.NewError(helpers.ErrInternal, "Failed to generate auth token", err) + return utils.NewError(utils.ErrInternal, "Failed to generate auth token", err) } // Set token cookie @@ -91,21 +91,21 @@ func Login(c *fiber.Ctx) error { u := new(loginStruct) if err := c.BodyParser(u); err != nil { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid request body", err) + return utils.NewError(utils.ErrInvalidInput, "Invalid request body", err) } // Validate input if u.Username == "" { - return helpers.NewError(helpers.ErrInvalidInput, "Username is empty", nil) + return utils.NewError(utils.ErrInvalidInput, "Username is empty", nil) } if u.Password == "" { - return helpers.NewError(helpers.ErrInvalidInput, "Password is empty", nil) + return utils.NewError(utils.ErrInvalidInput, "Password is empty", nil) } - if !helpers.IsValidUsername(u.Username) { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid username", nil) + if !utils.IsValidUsername(u.Username) { + return utils.NewError(utils.ErrInvalidInput, "Invalid username", nil) } - if !helpers.IsValidPassword(u.Password) { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid password", nil) + if !utils.IsValidPassword(u.Password) { + return utils.NewError(utils.ErrInvalidInput, "Invalid password", nil) } // Check if user exists @@ -114,7 +114,7 @@ func Login(c *fiber.Ctx) error { return err } if !exist { - return helpers.NewError(helpers.ErrNotFound, "User does not exist", nil) + return utils.NewError(utils.ErrNotFound, "User does not exist", nil) } // Verify password @@ -123,7 +123,7 @@ func Login(c *fiber.Ctx) error { return err } if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(u.Password)); err != nil { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid password", nil) + return utils.NewError(utils.ErrInvalidInput, "Invalid password", nil) } // Get user ID @@ -139,7 +139,7 @@ func Login(c *fiber.Ctx) error { }) signedToken, err := token.SignedString([]byte(os.Getenv("JWT_SECRET"))) if err != nil { - return helpers.NewError(helpers.ErrInternal, "Failed to generate token", err) + return utils.NewError(utils.ErrInternal, "Failed to generate token", err) } // Set token cookie @@ -159,13 +159,13 @@ func Login(c *fiber.Ctx) error { func ValidateToken(c *fiber.Ctx) error { username, ok := c.Locals("username").(string) if !ok { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid token: missing username", fmt.Errorf("missing username: %v", c.Locals("username"))) + return utils.NewError(utils.ErrInvalidInput, "Invalid token: missing username", fmt.Errorf("missing username: %v", c.Locals("username"))) } userIDVal := c.Locals("userID") userID, ok := userIDVal.(uuid.UUID) if !ok { - return helpers.NewError(helpers.ErrUnauthorized, "unauthorized", fmt.Errorf("missing/invalid userID type: %T, value: %v\n", userIDVal, userIDVal)) + return utils.NewError(utils.ErrUnauthorized, "unauthorized", fmt.Errorf("missing/invalid userID type: %T, value: %v\n", userIDVal, userIDVal)) } return c.Status(fiber.StatusOK).JSON(fiber.Map{ diff --git a/handlers/contacts.go b/handlers/contacts.go index 55cdec5..3dc2787 100644 --- a/handlers/contacts.go +++ b/handlers/contacts.go @@ -4,7 +4,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/google/uuid" "relay-server/database" - "relay-server/helpers" + "relay-server/utils" ) func DeleteContact(c *fiber.Ctx) error { @@ -15,14 +15,14 @@ func DeleteContact(c *fiber.Ctx) error { p := new(params) if err := c.ParamsParser(p); err != nil { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err) + return utils.NewError(utils.ErrInvalidInput, "Invalid params", err) } if p.ContactID == uuid.Nil { - return helpers.NewError(helpers.ErrInvalidInput, "contact ID is empty", nil) + return utils.NewError(utils.ErrInvalidInput, "contact ID is empty", nil) } if p.ConversationID == uuid.Nil { - return helpers.NewError(helpers.ErrInvalidInput, "conversation ID is empty", nil) + return utils.NewError(utils.ErrInvalidInput, "conversation ID is empty", nil) } err := database.DeleteContact(database.DB, p.ContactID, p.ConversationID) @@ -42,15 +42,15 @@ func InsertContact(c *fiber.Ctx) error { p := new(params) if err := c.ParamsParser(p); err != nil { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err) + return utils.NewError(utils.ErrInvalidInput, "Invalid params", err) } if p.ContactUsername == "" { - return helpers.NewError(helpers.ErrInvalidInput, "contact username is empty", nil) + return utils.NewError(utils.ErrInvalidInput, "contact username is empty", nil) } - if !helpers.IsValidUsername(p.ContactUsername) { - return helpers.NewError(helpers.ErrInvalidInput, "invalid username", nil) + if !utils.IsValidUsername(p.ContactUsername) { + return utils.NewError(utils.ErrInvalidInput, "invalid username", nil) } contactID, err := database.GetUserID(database.DB, p.ContactUsername) @@ -84,10 +84,10 @@ func GetContactSuggestions(c *fiber.Ctx) error { p := new(params) if err := c.ParamsParser(p); err != nil { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err) + return utils.NewError(utils.ErrInvalidInput, "Invalid params", err) } if p.ContactUsername == "" { - return helpers.NewError(helpers.ErrInvalidInput, "contact username is empty", nil) + return utils.NewError(utils.ErrInvalidInput, "contact username is empty", nil) } suggestions, err := database.ContactSuggestion(database.DB, p.ContactUsername) diff --git a/handlers/groups.go b/handlers/groups.go index 9ec18a2..4083630 100644 --- a/handlers/groups.go +++ b/handlers/groups.go @@ -6,7 +6,7 @@ import ( "github.com/google/uuid" "log" "relay-server/database" - "relay-server/helpers" + "relay-server/utils" ) func CreateGroup(c *fiber.Ctx) error { @@ -18,11 +18,11 @@ func CreateGroup(c *fiber.Ctx) error { var req createGroupRequest if err := c.BodyParser(&req); err != nil { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid request body", err) + return utils.NewError(utils.ErrInvalidInput, "Invalid request body", err) } if req.GroupName == "" { - return helpers.NewError(helpers.ErrInvalidInput, "Group name is empty", nil) + return utils.NewError(utils.ErrInvalidInput, "Group name is empty", nil) } groupID, err := database.CreateGroup(database.DB, req.GroupName, userID) @@ -41,13 +41,13 @@ func AddMemberToGroup(c *fiber.Ctx) error { } var req addMemberToGroupRequest if err := c.BodyParser(&req); err != nil { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid request body", err) + return utils.NewError(utils.ErrInvalidInput, "Invalid request body", err) } if req.GroupID == uuid.Nil { - return helpers.NewError(helpers.ErrInvalidInput, "Group ID is empty", nil) + return utils.NewError(utils.ErrInvalidInput, "Group ID is empty", nil) } if req.UserID == uuid.Nil { - return helpers.NewError(helpers.ErrInvalidInput, "User ID is empty", nil) + return utils.NewError(utils.ErrInvalidInput, "User ID is empty", nil) } isAdmin, err := database.IsAdmin(database.DB, req.UserID, req.GroupID) @@ -55,7 +55,7 @@ func AddMemberToGroup(c *fiber.Ctx) error { return err } if !isAdmin { - return helpers.NewError(helpers.ErrUnauthorized, "You are not a group administrator", nil) + return utils.NewError(utils.ErrUnauthorized, "You are not a group administrator", nil) } _, err = database.AddMemberToGroup(database.DB, req.GroupID, req.UserID) @@ -75,15 +75,15 @@ func GetMembers(c *fiber.Ctx) error { var p params if err := c.ParamsParser(&p); err != nil { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err) + return utils.NewError(utils.ErrInvalidInput, "Invalid params", err) } - isMember, err := database.IsMember(database.DB, p.GroupID, c.Locals("userID").(uuid.UUID)) + isMember, err := database.IsMember(database.DB, c.Locals("userID").(uuid.UUID), p.GroupID) if err != nil { - return helpers.NewError(helpers.ErrInternal, "Failed to get members", fmt.Errorf("failed to check if user is a member: %w", err)) + return utils.NewError(utils.ErrInternal, "Failed to get members", fmt.Errorf("failed to check if user is a member: %w", err)) } if !isMember { - return helpers.NewError(helpers.ErrForbidden, "You are not a member of this group", nil) + return utils.NewError(utils.ErrForbidden, "You are not a member of this group", nil) } members, err := database.GetMembers(database.DB, p.GroupID) @@ -91,5 +91,5 @@ func GetMembers(c *fiber.Ctx) error { return err } - return c.Status(fiber.StatusOK).JSON(fiber.Map{"members": members}) + return c.Status(fiber.StatusOK).JSON(members) } diff --git a/handlers/messages.go b/handlers/messages.go index 1e16749..2c5c58a 100644 --- a/handlers/messages.go +++ b/handlers/messages.go @@ -4,7 +4,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/google/uuid" "relay-server/database" - "relay-server/helpers" + "relay-server/utils" ) func GetMessages(c *fiber.Ctx) error { @@ -20,13 +20,13 @@ func GetMessages(c *fiber.Ctx) error { p := new(params) q := new(query) if err := c.ParamsParser(p); err != nil { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err) + return utils.NewError(utils.ErrInvalidInput, "Invalid params", err) } if err := c.QueryParser(q); err != nil { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid query", err) + return utils.NewError(utils.ErrInvalidInput, "Invalid query", err) } if p.conversationID == uuid.Nil { - return helpers.NewError(helpers.ErrInvalidInput, "conversation ID is empty", nil) + return utils.NewError(utils.ErrInvalidInput, "conversation ID is empty", nil) } messages, err := database.GetMessages(database.DB, userID, p.conversationID, q.limit, q.cursor) diff --git a/main.go b/main.go index 31086ba..a702fb0 100644 --- a/main.go +++ b/main.go @@ -1,17 +1,27 @@ package main import ( + "github.com/gofiber/contrib/websocket" "github.com/gofiber/fiber/v2" "log" "relay-server/database" - "relay-server/helpers" "relay-server/router" + "relay-server/utils" ) func main() { app := fiber.New(fiber.Config{ - ErrorHandler: helpers.ErrorHandler, + ErrorHandler: utils.ErrorHandler, }) + + app.Use(func(c *fiber.Ctx) error { + if websocket.IsWebSocketUpgrade(c) { + c.Locals("allowed", true) + return c.Next() + } + return fiber.ErrUpgradeRequired + }) + db, err := database.Init() if err != nil { log.Fatal("Failed to initialize database") diff --git a/middleware/protected.go b/middleware/protected.go index e60cb86..a55fc95 100644 --- a/middleware/protected.go +++ b/middleware/protected.go @@ -7,8 +7,8 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "os" - "relay-server/helpers" "relay-server/model" + "relay-server/utils" ) func Protected() fiber.Handler { @@ -22,7 +22,7 @@ func Protected() fiber.Handler { claims := user.Claims.(*model.UserClaims) userID, err := uuid.Parse(claims.UserID) if err != nil { - return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to parse user ID: %w", err)) + return utils.NewError(utils.ErrInternal, "internal server error", fmt.Errorf("failed to parse user ID: %w", err)) } c.Locals("userID", userID) c.Locals("username", claims.Username) diff --git a/model/model.go b/model/model.go index d539b07..28036f0 100644 --- a/model/model.go +++ b/model/model.go @@ -30,11 +30,13 @@ type ContactSuggestion struct { } type Message struct { - MessageID int `json:"message_id"` - Message string `json:"message"` - SentAt time.Time `json:"sent_at"` - Sender string `json:"sender"` - AttachmentUrl *string `json:"attachment_url"` + ID int `json:"message_id"` + Message string `json:"message"` + SentAt time.Time `json:"sent_at"` + Sender string `json:"sender"` + SenderID uuid.UUID `json:"sender_id"` + AttachmentUrl *string `json:"attachment_url"` + ConversationID uuid.UUID `json:"conversation_id"` } type CreateGroupResponse struct { diff --git a/router/router.go b/router/router.go index b80c0a6..d3ff93f 100644 --- a/router/router.go +++ b/router/router.go @@ -8,6 +8,7 @@ import ( ) func SetupRoutes(app *fiber.App) { + app.Get("/", middleware.Protected(), func(c *fiber.Ctx) error { return c.SendString("Hello, World!") }) @@ -36,4 +37,8 @@ func SetupRoutes(app *fiber.App) { groups.Post("/create", handlers.CreateGroup) groups.Post("/addMember", handlers.AddMemberToGroup) groups.Get("/getMembers/:groupID", handlers.GetMembers) + + // Socket group + socket := chat.Group("/ws", middleware.Protected(), logger.New()) + socket.Get("/:id") } diff --git a/socket/socket.go b/socket/socket.go new file mode 100644 index 0000000..4d283b4 --- /dev/null +++ b/socket/socket.go @@ -0,0 +1,258 @@ +package socket + +import ( + "fmt" + "github.com/google/uuid" + "log" + "relay-server/database" + "relay-server/utils" + + socketio "github.com/googollee/go-socket.io" +) + +// Message represents the chat message structure +type Message struct { + Content string `json:"message"` + Recipient string `json:"recipient"` + RecipientID uuid.UUID `json:"recipient_id"` + AttachmentURLs []string `json:"attachment_urls"` +} + +// DeleteMessageData represents the message deletion structure +type DeleteMessageData struct { + ConversationID uuid.UUID `json:"conversation_id"` + MessageID int `json:"message_id"` +} + +// GroupUserData represents group-user operations data +type GroupUserData struct { + GroupID uuid.UUID `json:"group_id"` + UserID uuid.UUID `json:"user_id"` +} + +// MessageReadData represents message read status data +type MessageReadData struct { + ConversationID uuid.UUID `json:"conversation_id"` + MessageID int `json:"message_id"` +} + +// SocketResponse represents a standard socket response +type SocketResponse struct { + Status string `json:"status"` + Message string `json:"message"` +} + +// InitializeSocket sets up and configures the Socket.IO server +func InitializeSocket() (*socketio.Server, error) { + server := socketio.NewServer(nil) + + // Middleware for authentication + server.OnConnect("/", func(s socketio.Conn) error { + token := s.RemoteHeader().Get("Authorization") + if token == "" { + log.Println("(socket) Not logged in") + return fmt.Errorf("not logged in") + } + + c, err := utils.ValidateToken(token) + if err != nil { + log.Printf("(socket) Token verification failed: %v\n", err) + return fmt.Errorf("invalid token") + } + + if !utils.IsValidUsername(c.Username) { + log.Println("(socket) Invalid username") + return fmt.Errorf("invalid username") + } + + s.SetContext(map[string]interface{}{ + "username": c.Username, + "user_id": c.UserID, + }) + + log.Printf("(socket) socket id: %s, username: %s, user_id: %s\n", s.ID(), c.Username, c.UserID) + return nil + }) + + // Handle connection + server.OnEvent("/", "connection", func(s socketio.Conn) { + ctx := s.Context().(map[string]interface{}) + username := ctx["username"].(string) + userID := ctx["user_id"].(string) + + if !utils.IsValidUsername(username) { + s.Close() + return + } + + conversations, err := database.GetUserConversations(database.DB, userID) + if err != nil { + log.Printf("(socket) Failed to get user conversations: %v\n", err) + return + } + + // Join all conversations + for _, conv := range conversations { + s.Join(conv) + } + s.Join(userID) // Join user's personal room + log.Printf("User: %s joined to: %v\n", username, conversations) + }) + + // Handle chat message + server.OnEvent("/", "chat message", func(s socketio.Conn, msg Message) SocketResponse { + ctx := s.Context().(map[string]interface{}) + username := ctx["username"].(string) + userIDstr := ctx["user_id"].(string) + conversationIDstr := ctx["recipient_id"].(string) + userID, err := uuid.Parse(userIDstr) + if err != nil { + return SocketResponse{Status: "error", Message: "Invalid user id"} + } + conversationID, err := uuid.Parse(conversationIDstr) + if err != nil { + return SocketResponse{Status: "error", Message: "Invalid conversation id"} + } + + if msg.Content == "" && len(msg.AttachmentURLs) == 0 { + return SocketResponse{Status: "error", Message: "No message or attachment provided"} + } + if msg.Recipient == "" { + return SocketResponse{Status: "error", Message: "No recipient provided"} + } + + insertedMsg, err := database.InsertMessage(database.DB, userID, conversationID, msg.Content, msg.AttachmentURLs) + if err != nil { + log.Printf("(socket) Failed to insert message: %v\n", err) + return SocketResponse{Status: "error", Message: "Failed to insert message"} + } + + // Emit message to recipients + server.BroadcastToRoom("", msg.Recipient, "chat message", map[string]interface{}{ + "sender": username, + "message": insertedMsg.Message, + "attachment_urls": msg.AttachmentURLs, + "recipient": msg.Recipient, + "message_id": insertedMsg.ID, + "sender_id": userID, + "sent_at": insertedMsg.SentAt, + "conversation_id": insertedMsg.ConversationID, + }) + + return SocketResponse{Status: "ok", Message: "Received message"} + }) + + // Handle delete message + server.OnEvent("/", "delete message", func(s socketio.Conn, data DeleteMessageData) SocketResponse { + ctx := s.Context().(map[string]interface{}) + userIDstr := ctx["user_id"].(string) + + userID, err := uuid.Parse(userIDstr) + if err != nil { + return SocketResponse{Status: "error", Message: "Invalid user id"} + } + + if data.MessageID == 0 { + return SocketResponse{Status: "error", Message: "No message id provided"} + } + if data.ConversationID == uuid.Nil { + return SocketResponse{Status: "error", Message: "No conversation id provided"} + } + + err = database.DeleteMessage(database.DB, userID, data.ConversationID, data.MessageID) + if err != nil { + return SocketResponse{Status: "error", Message: err.Error()} + } + + conversationIDstr := data.ConversationID.String() + server.BroadcastToRoom("", conversationIDstr, "delete message", data) + return SocketResponse{Status: "ok", Message: "Successfully deleted message"} + }) + + // Handle remove user from group + server.OnEvent("/", "remove user from group", func(s socketio.Conn, data GroupUserData) SocketResponse { + if data.GroupID == uuid.Nil { + return SocketResponse{Status: "error", Message: "No group id provided"} + } + if data.UserID == uuid.Nil { + return SocketResponse{Status: "error", Message: "No user id provided"} + } + + err := removeUserFromGroupById(data.GroupID, data.UserID) + if err != nil { + return SocketResponse{Status: "error", Message: err.Error()} + } + + // Remove user from room + sockets := server.Sockets(data.GroupID) + for _, socket := range sockets { + if socket.Context().(map[string]interface{})["user_id"] == data.UserID { + socket.Leave(data.GroupID) + } + } + + server.BroadcastToRoom("", data.GroupID, "left group", data) + server.BroadcastToRoom("", data.UserID, "left group", data) + + return SocketResponse{Status: "ok", Message: "Successfully removed user from group"} + }) + + // Handle administrator operations + server.OnEvent("/", "added administrator", func(s socketio.Conn, data GroupUserData) SocketResponse { + ctx := s.Context().(map[string]interface{}) + userID := ctx["user_id"].(string) + + if data.GroupID == "" { + return SocketResponse{Status: "error", Message: "No conversation id provided"} + } + if data.UserID == "" { + return SocketResponse{Status: "error", Message: "No user id provided"} + } + + isAdmin, err := isAdmin(userID, data.GroupID) + if err != nil || !isAdmin { + return SocketResponse{Status: "error", Message: "You are not an administrator"} + } + + err = addAdministrator(data.GroupID, data.UserID, userID) + if err != nil { + return SocketResponse{Status: "error", Message: err.Error()} + } + + server.BroadcastToRoom("", data.GroupID, "added administrator", data) + return SocketResponse{Status: "ok", Message: "Successfully added administrator"} + }) + + // Handle message read status + server.OnEvent("/", "message read", func(s socketio.Conn, data MessageReadData) { + ctx := s.Context().(map[string]interface{}) + userID := ctx["user_id"].(string) + + if data.ConversationID == "" || data.MessageID == "" { + return + } + + err := updateContactStatus(userID, data.ConversationID, data.MessageID) + if err != nil { + log.Printf("Failed to update message read status: %v\n", err) + } + }) + + // Handle disconnection + server.OnDisconnect("/", func(s socketio.Conn, reason string) { + log.Printf("(socket) %s disconnected due to: %s\n", s.ID(), reason) + }) + + return server, nil +} + +// The following functions would need to be implemented according to your specific needs: +// - verifyJwtToken +// - isValidUsername +// - getConversationsForUser +// - insertMessage +// - deleteMessage +// - removeUserFromGroupById +// - isAdmin +// - addAdministrator +// - updateContactStatus diff --git a/helpers/errorHandler.go b/utils/errorHandler.go similarity index 98% rename from helpers/errorHandler.go rename to utils/errorHandler.go index bb7a650..0283df5 100644 --- a/helpers/errorHandler.go +++ b/utils/errorHandler.go @@ -1,4 +1,4 @@ -package helpers +package utils import ( "errors" diff --git a/helpers/filter.go b/utils/filter.go similarity index 98% rename from helpers/filter.go rename to utils/filter.go index bb0bd60..f6700b8 100644 --- a/helpers/filter.go +++ b/utils/filter.go @@ -1,4 +1,4 @@ -package helpers +package utils import ( "regexp" diff --git a/utils/token.go b/utils/token.go new file mode 100644 index 0000000..9c97bd9 --- /dev/null +++ b/utils/token.go @@ -0,0 +1,54 @@ +package utils + +import ( + "errors" + "fmt" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "os" + "relay-server/model" +) + +func GenerateToken(userID uuid.UUID, username string) (string, error) { + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "user_id": userID, + "username": username, + }) + signedToken, err := token.SignedString([]byte(os.Getenv("JWT_SECRET"))) + if err != nil { + return "", NewError(ErrInternal, "Failed to generate token", err) + } + return signedToken, nil +} + +func ValidateToken(tokenString string) (*model.UserClaims, error) { + secretKey := os.Getenv("JWT_SECRET") + + // Parse the token + token, err := jwt.ParseWithClaims(tokenString, &model.UserClaims{}, func(token *jwt.Token) (interface{}, error) { + // Validate the signing method + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return []byte(secretKey), nil + }) + + if err != nil { + return nil, fmt.Errorf("failed to parse token: %v", err) + } + + // Check if the token is valid + if !token.Valid { + return nil, errors.New("invalid token") + } + + // Type assert the claims + claims, ok := token.Claims.(*model.UserClaims) + if !ok { + return nil, errors.New("failed to parse claims") + } + + return claims, nil +}