diff --git a/.idea/dataSources.xml b/.idea/dataSources.xml
new file mode 100644
index 0000000..321d912
--- /dev/null
+++ b/.idea/dataSources.xml
@@ -0,0 +1,17 @@
+
+
+
+
+ postgresql
+ true
+ org.postgresql.Driver
+ jdbc:postgresql://192.168.0.47:5432/relay
+
+
+
+
+
+ $ProjectFileDir$
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/database/auth.go b/database/auth.go
index 5e2a64f..d7bf66d 100644
--- a/database/auth.go
+++ b/database/auth.go
@@ -3,6 +3,7 @@ package database
import (
"database/sql"
"fmt"
+ "github.com/google/uuid"
_ "github.com/lib/pq"
"log"
)
@@ -34,7 +35,7 @@ func CheckUserExists(db *sql.DB, username string) (bool, error) {
err := db.QueryRow(query, username).Scan(&count)
if err != nil {
- return false, fmt.Errorf("error checking username exists: %v", err)
+ return false, fmt.Errorf("error checking username exists: %w", err)
}
return count > 0, err
@@ -47,15 +48,15 @@ func InsertUser(db *sql.DB, username string, passwordHash string) (string, error
RETURNING user_id;
`
- var userId string
- err := db.QueryRow(query, username, passwordHash).Scan(&userId)
+ var userID string
+ err := db.QueryRow(query, username, passwordHash).Scan(&userID)
if err != nil {
- return "", fmt.Errorf("error inserting user: %v", err)
+ return "", fmt.Errorf("error inserting user: %w", err)
}
- log.Printf("Inserted user: %v", username)
+ fmt.Printf("Inserted user: %w", username)
- return userId, err
+ return userID, err
}
func GetPasswordHash(db *sql.DB, username string) (string, error) {
@@ -67,24 +68,22 @@ func GetPasswordHash(db *sql.DB, username string) (string, error) {
var passwordHash string
err := db.QueryRow(query, username).Scan(&passwordHash)
if err != nil {
- fmt.Printf("error getting password hash: %v\n", err)
- return "", fmt.Errorf("error getting password hash: %v", err)
+ return "", fmt.Errorf("error getting password hash: %w", err)
}
return passwordHash, err
}
-func GetUserId(db *sql.DB, username string) (string, error) {
+func GetUserID(db *sql.DB, username string) (uuid.UUID, error) {
query := `
SELECT user_id FROM Accounts
- WHERE LOWER(username) = $1;
+ WHERE LOWER(username) = LOWER($1);
`
- var dbUsername string
- err := db.QueryRow(query, username).Scan(&dbUsername)
+ var userID uuid.UUID
+ err := db.QueryRow(query, username).Scan(&userID)
if err != nil {
- fmt.Printf("Error getting password: %v\n", err)
- return "", fmt.Errorf("error getting user id: %v", err)
+ return uuid.Nil, fmt.Errorf("error getting user id: %w", err)
}
- return dbUsername, err
+ return userID, nil
}
diff --git a/database/contacts.go b/database/contacts.go
index 9e5c67a..5b8cc8a 100644
--- a/database/contacts.go
+++ b/database/contacts.go
@@ -5,7 +5,7 @@ import (
"errors"
"fmt"
"github.com/google/uuid"
- "log"
+ "relay-server/model"
)
func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (string, error) {
@@ -58,8 +58,6 @@ func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (stri
if rowsAffected == 0 {
return "", fmt.Errorf("no matching membership found with conversation id: %s, user id: %s", conversationID, userID)
}
-
- log.Printf("Successfully removed user %s from group %s", userID, conversationID)
} else {
// Handle direct conversation
res, err := db.Exec(
@@ -68,21 +66,198 @@ func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (stri
conversationID,
)
if err != nil {
- log.Printf("Error deleting contact: %v", err)
return "", fmt.Errorf("error deleting contact: %w", err)
}
rowsAffected, err := res.RowsAffected()
if err != nil {
- log.Printf("Error checking contact deletion: %v", err)
return "", fmt.Errorf("error checking contact deletion: %w", err)
}
if rowsAffected == 0 {
return fmt.Sprintf("no matching contact found with user id: %s, conversation id: %s", userID, conversationID), nil
}
- log.Printf("Successfully deleted contact for user %s in conversation %s", userID, conversationID)
+ fmt.Printf("Successfully deleted contact for user %s in conversation %s", userID, conversationID)
}
return "", nil
}
+
+func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUsername string) (*model.Contact, error) {
+ isSelfContact := userID == contactID
+ var contact model.Contact
+ var conversationID uuid.UUID
+
+ if isSelfContact {
+ findSelfConversationQuery := `
+ SELECT c.conversation_id
+ FROM Conversations c
+ JOIN Memberships m ON c.conversation_id = m.conversation_id
+ WHERE c.conversation_type = 'direct'
+ AND m.user_id = $1
+ AND (
+ SELECT COUNT(*)
+ FROM Memberships
+ WHERE conversation_id = c.conversation_id
+ ) = 1
+ LIMIT 1;
+ `
+
+ err := db.QueryRow(findSelfConversationQuery, userID).Scan(&conversationID)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ conversationID = uuid.Nil
+ }
+ return nil, fmt.Errorf("error finding existing conversation: %w", err)
+ }
+
+ // if conversation for themselves don't exist
+ if conversationID == uuid.Nil {
+ createConversationQuery := `
+ INSERT INTO Conversations (conversation_type)
+ VALUES ('direct')
+ RETURNING conversation_id;
+ `
+ err := db.QueryRow(createConversationQuery).Scan(&conversationID)
+ if err != nil {
+ return nil, fmt.Errorf("error creating conversation for self-contact: %w", err)
+ }
+
+ createMembershipQuery := `
+ INSERT INTO Memberships (conversation_id, user_id)
+ VALUES ($1, $2)
+ ON CONFLICT (conversation_id, user_id) DO NOTHING;
+ `
+ _, err = db.Exec(createMembershipQuery, conversationID, userID)
+
+ } else {
+ // For regular contacts, check if a conversation already exists between the two users
+ findConversationQuery := `
+ SELECT c.conversation_id
+ FROM Conversations c
+ JOIN Memberships m1 ON c.conversation_id = m1.conversation_id
+ JOIN Memberships m2 ON c.conversation_id = m2.conversation_id
+ WHERE c.conversation_type = 'direct'
+ AND (
+ (m1.user_id = $1 AND m2.user_id = $2)
+ OR
+ (m1.user_id = $2 AND m2.user_id = $1)
+ )
+ LIMIT 1;
+ `
+ err := db.QueryRow(findConversationQuery, userID, contactID).Scan(&conversationID)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ conversationID = uuid.Nil
+ }
+ return nil, fmt.Errorf("error finding existing conversation: %w", err)
+ }
+ if conversationID != uuid.Nil {
+ // Create a new conversation between user and contact
+ createConversationQuery := `
+ INSERT INTO Conversations (conversation_type)
+ VALUES ('direct')
+ RETURNING conversation_id;
+ `
+ err := db.QueryRow(createConversationQuery).Scan(&conversationID)
+ if err != nil {
+ return nil, fmt.Errorf("error creating conversation: %w", err)
+ }
+
+ createMembershipQuery := `
+ INSERT INTO Memberships (conversation_id, user_id)
+ VALUES ($1, $2), ($1, $3)
+ ON CONFLICT (conversation_id, user_id) DO NOTHING;
+ `
+
+ res, err := db.Exec(createMembershipQuery, conversationID, userID, contactID)
+ if err != nil {
+ return nil, fmt.Errorf("error creating membership: %w", err)
+ }
+ rowsAffected, err := res.RowsAffected()
+ if err != nil {
+ return nil, fmt.Errorf("error checking membership creation: %w", err)
+ }
+ if rowsAffected == 0 {
+ return nil, fmt.Errorf("error creating membership %w", err)
+ }
+ }
+ }
+
+ }
+ insertedContact, err := InsertContactByID(db, contactID, conversationID)
+ if err != nil || insertedContact.UserID == uuid.Nil {
+ return nil, fmt.Errorf("error inserting contact by id: %w", err)
+ }
+
+ latestMessage, err := GetLatestMessage(db, conversationID)
+ if err != nil {
+ return nil, fmt.Errorf("error getting latest message: %w", err)
+ }
+
+ contact = model.Contact{
+ ID: insertedContact.ID,
+ ConversationID: insertedContact.ConversationID,
+ UserID: insertedContact.UserID,
+ Username: contactUsername,
+ Type: "direct",
+ LastMessageID: latestMessage.LastMessageID,
+ LastMessage: latestMessage.LastMessage,
+ LastMessageTime: latestMessage.LastMessageTime,
+ LastMessageSender: latestMessage.LastMessageSender,
+ }
+ return &contact, nil
+}
+
+func InsertContactByID(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (*model.Contact, error) {
+ // First check if contact already exists
+ checkQuery := `
+ SELECT contact_id, conversation_id, user_id
+ FROM Contacts
+ WHERE user_id = $1 AND conversation_id = $2
+ `
+ var contact model.Contact
+ err := db.QueryRow(checkQuery, userID, conversationID).Scan(&contact.ID, &contact.ConversationID, &contact.UserID)
+ if err == nil {
+ return nil, fmt.Errorf("contact already exists")
+ } else if !errors.Is(err, sql.ErrNoRows) {
+ return nil, fmt.Errorf("error checking contact existence: %w", err)
+ }
+
+ insertQuery := `
+ INSERT INTO Contacts (user_id, conversation_id)
+ VALUES($1, $2)
+ RETURNING contact_id, conversation_id, user_id
+ `
+
+ err = db.QueryRow(insertQuery, userID, conversationID).Scan(&contact.ID, &contact.ConversationID, &contact.UserID)
+ if err != nil {
+ return nil, fmt.Errorf("error inserting contact: %w", err)
+ }
+ fmt.Printf("Successfully inserted contact by id: %v", conversationID)
+
+ return &contact, nil
+}
+
+func GetLatestMessage(db *sql.DB, conversationId uuid.UUID) (*model.Contact, error) {
+ var latestMessage model.Contact
+
+ query := `
+ SELECT DISTINCT ON (m.conversation_id)
+ m.message_id AS last_message_id,
+ m.content AS last_message,
+ m.sent_at AS last_message_time,
+ a.username AS last_message_sender
+ FROM Messages m
+ JOIN Accounts a ON m.user_id = a.user_id
+ WHERE m.conversation_id = $1
+ ORDER BY m.conversation_id, m.sent_at DESC
+ LIMIT 1;
+ `
+
+ err := db.QueryRow(query, conversationId).Scan(&latestMessage.LastMessageID, &latestMessage.LastMessage, &latestMessage.LastMessageTime, &latestMessage.LastMessageSender)
+ if err != nil {
+ return nil, fmt.Errorf("error getting latest message: %w", err)
+ }
+ return &latestMessage, nil
+}
diff --git a/handlers/auth.go b/handlers/auth.go
index 3c2b67a..c4df198 100644
--- a/handlers/auth.go
+++ b/handlers/auth.go
@@ -1,10 +1,10 @@
package handlers
import (
- "fmt"
"github.com/gofiber/fiber/v2"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
+ "log"
"os"
"relay-server/config"
"relay-server/database"
@@ -49,20 +49,20 @@ func Signup(c *fiber.Ctx) error {
// Create password hash
passwordHash, err := bcrypt.GenerateFromPassword([]byte(u.Password), config.BCRYPT_COST)
if err != nil {
- fmt.Printf("error hashing password: %v\n", err)
+ log.Printf("error hashing password: %w\n", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "internal server error"})
}
// Insert username and password hash to database
- userId, err := database.InsertUser(db, u.Username, string(passwordHash))
+ userID, err := database.InsertUser(db, u.Username, string(passwordHash))
if err != nil {
- fmt.Printf("error inserting user: %v\n", err)
+ log.Print(err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
// Generate token with user id and username
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
- "user_id": userId,
+ "user_id": userID,
"username": u.Username,
})
// Sign token
@@ -77,7 +77,7 @@ func Signup(c *fiber.Ctx) error {
c.Cookie(tokenCookie)
// If everything went well sent username and user_id assigned by database
- return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Successfully signed up", "username": u.Username, "user_id": userId})
+ return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Successfully signed up", "username": u.Username, "user_id": userID})
}
func Login(c *fiber.Ctx) error {
@@ -117,19 +117,22 @@ func Login(c *fiber.Ctx) error {
}
// Verifies password matching
- passwordHash, _ := database.GetPasswordHash(db, u.Username)
+ passwordHash, err := database.GetPasswordHash(db, u.Username)
+ if err != nil {
+ log.Printf("error getting password: %w\n", err)
+ }
if bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(u.Password)) != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "invalid password"})
}
- userId, err := database.GetUserId(db, u.Username)
+ userID, err := database.GetUserID(db, u.Username)
if err != nil {
- fmt.Printf("error getting user id: %v\n", err)
+ log.Print(err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "internal server error"})
}
// Generate token with user id and username
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
- "user_id": userId,
+ "user_id": userID,
"username": u.Username,
})
// Sign token
@@ -140,19 +143,20 @@ func Login(c *fiber.Ctx) error {
tokenCookie.Name = "token"
tokenCookie.Value = signedToken
tokenCookie.Expires = time.Now().Add(30 * 24 * time.Hour)
- //tokenCookie.HTTPOnly = true
c.Cookie(tokenCookie)
- return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Successfully logged in", "username": u.Username, "user_id": userId})
+ return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Successfully logged in", "username": u.Username, "user_id": userID})
}
func ValidateToken(c *fiber.Ctx) error {
- username := c.Locals("username")
- userId := c.Locals("user_id")
+ username := c.Locals("username").(string)
+ userID := c.Locals("userID").(string)
- if userId == nil || username == nil {
- fmt.Println("userId or username is nil")
- return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid token"})
- }
- return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "authorized", "username": c.Locals("username").(string), "user_id": c.Locals("userId").(string)})
+ //log.Printf("userID: %v, username: %v", userID, username)
+ //if userID == "" || username == "" {
+ // log.Printf("userID or username is empty %v", c.Locals("username"))
+ // return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid token"})
+ //}
+
+ return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "authorized", "username": username, "user_id": userID})
}
diff --git a/handlers/contacts.go b/handlers/contacts.go
index ff7614e..f280155 100644
--- a/handlers/contacts.go
+++ b/handlers/contacts.go
@@ -1,17 +1,21 @@
package handlers
import (
+ "database/sql"
+ "errors"
+ "fmt"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"log"
"relay-server/database"
+ "relay-server/helpers"
)
func DeleteContact(c *fiber.Ctx) error {
type params struct {
- ContactId uuid.UUID `params:"contact_id"`
- ConversationId uuid.UUID `params:"conversation_id"`
+ ContactID uuid.UUID `params:"contact_id"`
+ ConversationID uuid.UUID `params:"conversation_id"`
}
p := new(params)
@@ -21,22 +25,69 @@ func DeleteContact(c *fiber.Ctx) error {
db := database.DB
- if p.ContactId == uuid.Nil {
+ if p.ContactID == uuid.Nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "contact_id is empty"})
}
- if p.ConversationId == uuid.Nil {
+ if p.ConversationID == uuid.Nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "conversation_id is empty"})
}
- msg, err := database.DeleteContact(db, p.ContactId, p.ConversationId)
+ msg, err := database.DeleteContact(db, p.ContactID, p.ConversationID)
if err != nil {
log.Println(err)
- return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to delete contact"})
+ return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"})
}
if msg != "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": msg})
}
- log.Println("Contact deleted")
+ fmt.Println("Contact deleted")
return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Contact deleted"})
}
+
+func InsertContact(c *fiber.Ctx) error {
+ userID, err := uuid.Parse(c.Locals("userID").(string))
+ if err != nil {
+ log.Println(err)
+ return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Internal server error"})
+ }
+ type params struct {
+ ContactUsername string `params:"contact_username"`
+ }
+
+ p := new(params)
+
+ if err := c.ParamsParser(p); err != nil {
+ return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid params"})
+ }
+
+ db := database.DB
+
+ if p.ContactUsername == "" {
+ return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "contact_username is empty"})
+ }
+
+ if !helpers.IsValidUsername(p.ContactUsername) {
+ return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "username is invalid"})
+ }
+
+ contactID, err := database.GetUserID(db, p.ContactUsername)
+
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "user does not exist"})
+ }
+ log.Println(err)
+ return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Internal server error"})
+ }
+
+ newContacts, err := database.InsertContact(db, userID, contactID, p.ContactUsername)
+ if err != nil {
+ log.Println(err)
+ return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Internal server error"})
+ }
+
+ log.Println("Contact added")
+ return c.Status(fiber.StatusOK).JSON(newContacts)
+
+}
diff --git a/middleware/protected.go b/middleware/protected.go
index 2065a4c..32e0158 100644
--- a/middleware/protected.go
+++ b/middleware/protected.go
@@ -18,7 +18,7 @@ func Protected() fiber.Handler {
user := c.Locals("user").(*jwt.Token)
claims := user.Claims.(*model.UserClaims)
- c.Locals("userId", claims.UserId)
+ c.Locals("userID", claims.UserID)
c.Locals("username", claims.Username)
return c.Next()
},
diff --git a/model/model.go b/model/model.go
index c60c6ac..5d3ac96 100644
--- a/model/model.go
+++ b/model/model.go
@@ -1,9 +1,23 @@
package model
-import "github.com/golang-jwt/jwt/v5"
+import (
+ "github.com/golang-jwt/jwt/v5"
+ "github.com/google/uuid"
+)
type UserClaims struct {
Username string `json:"username"`
- UserId string `json:"user_id"`
+ UserID string `json:"user_id"`
jwt.RegisteredClaims
}
+type Contact struct {
+ ID int `json:"contact_id"`
+ ConversationID uuid.UUID `json:"conversation_id"`
+ UserID uuid.UUID `json:"user_id"`
+ Username string `json:"username"`
+ Type string `json:"type"`
+ LastMessageID int `json:"last_message_id"`
+ LastMessage string `json:"last_message"`
+ LastMessageTime string `json:"last_message_time"`
+ LastMessageSender string `json:"last_message_sender"`
+}
diff --git a/router/router.go b/router/router.go
index 550b180..35a931a 100644
--- a/router/router.go
+++ b/router/router.go
@@ -15,12 +15,13 @@ func SetupRoutes(app *fiber.App) {
chat := api.Group("/chat", middleware.Protected(), logger.New())
// Auth group
- auth := api.Group("/auth", middleware.Protected(), handlers.ValidateToken)
+ auth := api.Group("/auth", logger.New())
auth.Post("/signup", handlers.Signup)
auth.Post("/login", handlers.Login)
- auth.Get("/validate", handlers.ValidateToken)
+ auth.Get("/validate", middleware.Protected(), handlers.ValidateToken)
// Contacts group
contacts := chat.Group("/contacts", middleware.Protected(), logger.New())
contacts.Delete("/:contact_id/:conversation_id", handlers.DeleteContact)
+ contacts.Post("/:contact_username", handlers.InsertContact)
}