WebSockets - Part Two
Part 1, Chapter 7
In the previous chapter, we learned the basics of building apps with Django Channels. In this chapter, we’re going to use what we learned to build the server side of our ride-sharing app.
Joining the Driver Pool
The moment a driver logs into our app, he joins a pool of drivers that can accept requests from riders. We’ll test this by creating a driver and logging him in, sending a broadcast message to the driver group, and confirming that the driver receives the message.
Let’s start by editing our create_user() function to take a group parameter.
# tests/test_websocket.py
from django.contrib.auth.models import Group
@database_sync_to_async
def create_user(  # changed
    username,
    password,
    group='rider'
):
    # Create user.
    user = get_user_model().objects.create_user(
        username=username,
        password=password
    )
    # Create user group.
    user_group, _ = Group.objects.get_or_create(name=group)  # new
    user.groups.add(user_group)
    user.save()
    # Create access token.
    access = AccessToken.for_user(user)
    return user, access
Next, let’s add a new test.
async def test_join_driver_pool(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS
    _, access = await create_user(
        'test.user@example.com', 'pAssw0rd', 'driver'
    )
    communicator = WebsocketCommunicator(
        application=application,
        path=f'/taxi/?token={access}'
    )
    connected, _ = await communicator.connect()
    message = {
        'type': 'echo.message',
        'data': 'This is a test message.',
    }
    channel_layer = get_channel_layer()
    await channel_layer.group_send('drivers', message=message)
    response = await communicator.receive_json_from()
    assert response == message
    await communicator.disconnect()
Let’s update our consumer to subscribe a user to the driver group if he is a driver. Make the following changes to the connect() and disconnect() methods on the consumer.
# trips/consumers.py
async def connect(self):
    user = self.scope['user']
    if user.is_anonymous:
        await self.close()
    else:
        user_group = await self._get_user_group(user)  # new
        if user_group == 'driver':
            await self.channel_layer.group_add(
                group='drivers',
                channel=self.channel_name
            )
        await self.accept()
async def disconnect(self, code):
    user = self.scope['user']  # new
    user_group = await self._get_user_group(user)
    if user_group == 'driver':
        await self.channel_layer.group_discard(
            group='drivers',
            channel=self.channel_name
        )
    await super().disconnect(code)
Now, when the client establishes the WebSocket connection with the server, the server checks which group the authenticated user belongs to. If the user is a driver, the consumer adds the connection’s channel to the drivers group, which is our driver pool. When the WebSocket connection closes, the server removes the channel from that group where appropriate.
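The following framework-free sketch makes the group mechanics concrete. It models a channel layer as a set of per-channel inboxes plus a mapping from group names to member channels. ToyChannelLayer, connect(), disconnect() and demo() are hypothetical names written for this illustration and are not part of Channels. Real channel layers, such as the in-memory and Redis layers Channels supports, also handle message expiry, capacity limits and delivery across processes.

```python
import asyncio
from collections import defaultdict


class ToyChannelLayer:
    """Hypothetical in-memory stand-in for a channel layer (not Channels' class)."""

    def __init__(self):
        self.queues = defaultdict(asyncio.Queue)  # channel name -> inbox
        self.groups = defaultdict(set)            # group name -> channel names

    async def group_add(self, group, channel):
        self.groups[group].add(channel)

    async def group_discard(self, group, channel):
        self.groups[group].discard(channel)

    async def group_send(self, group, message):
        # Deliver a copy of the message to every channel currently in the group.
        for channel in self.groups[group]:
            await self.queues[channel].put(dict(message))

    async def receive(self, channel):
        return await self.queues[channel].get()


async def connect(layer, channel, user_group):
    # Mirrors the consumer: only drivers join the 'drivers' pool.
    if user_group == 'driver':
        await layer.group_add('drivers', channel)


async def disconnect(layer, channel, user_group):
    # Mirrors the consumer: drivers leave the pool when the socket closes.
    if user_group == 'driver':
        await layer.group_discard('drivers', channel)


async def demo():
    layer = ToyChannelLayer()
    await connect(layer, 'driver-1', 'driver')
    await connect(layer, 'rider-1', 'rider')
    await layer.group_send('drivers', {'type': 'echo.message', 'data': 'hi'})
    received = await layer.receive('driver-1')
    rider_inbox_empty = layer.queues['rider-1'].empty()
    await disconnect(layer, 'driver-1', 'driver')
    return received, rider_inbox_empty, layer.groups['drivers']
```

Running asyncio.run(demo()) returns three things: the message the driver received, True because the rider's inbox stayed empty, and an empty drivers group after the driver disconnects.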
For the last step, let’s add the _get_user_group() helper function to the consumer.
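# trips/consumers.py
from channels.db import database_sync_to_async
@database_sync_to_async
def _get_user_group(self, user):
    # Return the name of the user's first group, or None if the user has no
    # group (anonymous users never belong to one).
    group = user.groups.first()
    return group.name if group else None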
Before you run the tests, delete the test_can_send_and_receive_broadcast_messages() function. Our new test is essentially a duplicate of it that uses the drivers group name instead of test.
Run pytest and confirm that all tests are passing.
8 passed in 1.26s
Requesting a Trip
When a rider requests a trip, the server creates a new Trip record in the database and then broadcasts the rider’s request to all of the drivers in the driver pool.
Creating a Trip
It may not be obvious at first, but we need to expand our Trip data model before we can proceed. Right now, it has no concept of which rider requested a trip and which driver accepted it.
Add a rider and a driver field to the Trip model as shown below:
# trips/models.py
class Trip(models.Model):
    REQUESTED = 'REQUESTED'
    STARTED = 'STARTED'
    IN_PROGRESS = 'IN_PROGRESS'
    COMPLETED = 'COMPLETED'
    STATUSES = (
        (REQUESTED, REQUESTED),
        (STARTED, STARTED),
        (IN_PROGRESS, IN_PROGRESS),
        (COMPLETED, COMPLETED),
    )
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    created = models.DateTimeField(auto_now_add=True)
    updated = models.DateTimeField(auto_now=True)
    pick_up_address = models.CharField(max_length=255)
    drop_off_address = models.CharField(max_length=255)
    status = models.CharField(max_length=20, choices=STATUSES, default=REQUESTED)
    driver = models.ForeignKey(  # new
        settings.AUTH_USER_MODEL,
        null=True,
        blank=True,
        on_delete=models.DO_NOTHING,
        related_name='trips_as_driver'
    )
    rider = models.ForeignKey(  # new
        settings.AUTH_USER_MODEL,
        null=True,
        blank=True,
        on_delete=models.DO_NOTHING,
        related_name='trips_as_rider'
    )
    def __str__(self):
        return f'{self.id}'
    def get_absolute_url(self):
        return reverse('trip:trip_detail', kwargs={'trip_id': self.id})
Add the import:
from django.conf import settings
Make and run migrations to update our Trip model’s database table:
(env)$ python manage.py makemigrations trips --name trip_driver_rider
(env)$ python manage.py migrate
While we’re at it, let’s update our admin page to reflect the changes we just made to our model.
# trips/admin.py
from django.contrib import admin
from django.contrib.auth.admin import UserAdmin as DefaultUserAdmin
from .models import Trip, User
@admin.register(User)
class UserAdmin(DefaultUserAdmin):
    pass
@admin.register(Trip)
class TripAdmin(admin.ModelAdmin):
    fields = (  # changed
        'id', 'pick_up_address', 'drop_off_address', 'status',
        'driver', 'rider',
        'created', 'updated',
    )
    list_display = (  # changed
        'id', 'pick_up_address', 'drop_off_address', 'status',
        'driver', 'rider',
        'created', 'updated',
    )
    list_filter = (
        'status',
    )
    readonly_fields = (
        'id', 'created', 'updated',
    )
By default, our TripSerializer processes related models as primary keys. That is exactly the behavior we want when we use a serializer to create a database record. On the other hand, when we get serialized Trip data back from the server, we want more information about the rider and the driver than just their database IDs.
Let’s create a new NestedTripSerializer in trips/serializers.py, after our existing TripSerializer. The difference is that NestedTripSerializer serializes the full User object instead of its primary key.
# trips/serializers.py
class NestedTripSerializer(serializers.ModelSerializer):
    class Meta:
        model = Trip
        fields = '__all__'
        depth = 1
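One caveat about depth = 1: it builds the nested representation from every field on the related model. For User, that means the payload includes fields such as password (the stored hash), is_superuser and last_login. If you only want to expose a few fields, DRF lets you declare the nested field explicitly with a serializer whose Meta.fields lists just those fields. The framework-free sketch below shows the two payload shapes side by side. User, Trip, trip_with_pk(), trip_nested() and PUBLIC_USER_FIELDS are hypothetical names, not DRF code.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class User:
    id: int
    username: str
    password: str  # stands in for the stored password hash


@dataclass
class Trip:
    id: str
    status: str
    rider: Optional[User] = None
    driver: Optional[User] = None


# Hypothetical whitelist of user fields that are safe to expose.
PUBLIC_USER_FIELDS = ('id', 'username')


def trip_with_pk(trip):
    # Like the default TripSerializer: related objects collapse to primary keys.
    return {
        'id': trip.id,
        'status': trip.status,
        'rider': trip.rider.id if trip.rider else None,
        'driver': trip.driver.id if trip.driver else None,
    }


def trip_nested(trip, user_fields=PUBLIC_USER_FIELDS):
    # Like a nested serializer: related objects expand into dictionaries,
    # restricted here to an explicit list of fields.
    def user_dict(user):
        if user is None:
            return None
        return {name: getattr(user, name) for name in user_fields}

    return {
        'id': trip.id,
        'status': trip.status,
        'rider': user_dict(trip.rider),
        'driver': user_dict(trip.driver),
    }
```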
Now, let’s create a new test:
async def test_request_trip(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS
    user, access = await create_user(
        'test.user@example.com', 'pAssw0rd', 'rider'
    )
    communicator = WebsocketCommunicator(
        application=application,
        path=f'/taxi/?token={access}'
    )
    connected, _ = await communicator.connect()
    await communicator.send_json_to({
        'type': 'create.trip',
        'data': {
            'pick_up_address': '123 Main Street',
            'drop_off_address': '456 Piney Road',
            'rider': user.id,
        },
    })
    response = await communicator.receive_json_from()
    response_data = response.get('data')
    assert response_data['id'] is not None
    assert response_data['pick_up_address'] == '123 Main Street'
    assert response_data['drop_off_address'] == '456 Piney Road'
    assert response_data['status'] == 'REQUESTED'
    assert response_data['rider']['username'] == user.username
    assert response_data['driver'] is None
    await communicator.disconnect()
When a rider requests a trip, the server will create a new Trip record and will broadcast the request to the driver pool. From the rider’s perspective, though, the only message that comes back is a confirmation that the new trip was created. That’s what this test checks. (We’ll prove that drivers receive the broadcast message in another test.)
Let’s make some changes to our consumer.
# trips/consumers.py
# changed
async def receive_json(self, content, **kwargs):
    message_type = content.get('type')
    if message_type == 'create.trip':
        await self.create_trip(content)
    elif message_type == 'echo.message':
        await self.echo_message(content)
# new
async def create_trip(self, message):
    data = message.get('data')
    trip = await self._create_trip(data)
    await self.send_json({
        'type': 'echo.message',
        'data': NestedTripSerializer(trip).data,
    })
# changed
async def echo_message(self, message):
    await self.send_json(message)
# new
@database_sync_to_async
def _create_trip(self, data):
    serializer = TripSerializer(data=data)
    serializer.is_valid(raise_exception=True)
    return serializer.create(serializer.validated_data)
Add these imports:
from trips.serializers import NestedTripSerializer, TripSerializer
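All incoming messages are received by the receive_json() method in the consumer. This is where you should delegate the business logic for each message type. Our create_trip() method creates a new trip and passes the details back to the client. Note that we use a special decorated _create_trip() helper method to do the actual database update. Django’s ORM is synchronous, and database_sync_to_async runs the helper in a worker thread so that it doesn’t block the event loop. Serializing with NestedTripSerializer(trip).data can also query the database (for example, to load the nested users’ groups and permissions). On Django 3.0 or later, which raises SynchronousOnlyOperation when the ORM is used directly from async code, you may need to move that serialization into a database_sync_to_async helper as well.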
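Every message in this chapter carries 'type': 'echo.message', and that value matters for messages arriving through the channel layer. Channels chooses which consumer method to call by replacing the dots in type with underscores. A group_send() with 'type': 'echo.message' therefore runs echo_message() on each receiving consumer, which then forwards the message to its client with send_json(). Messages from the client take a different path: they arrive in receive_json(), which is why we branch on type there ourselves.

The sketch below imitates the channel-layer convention. handler_name() and Dispatcher are hypothetical, simplified stand-ins. Like them, Channels' own lookup also rejects type names that begin with an underscore.

```python
import asyncio


def handler_name(message):
    # Channels-style lookup: 'echo.message' -> 'echo_message'.
    name = message['type'].replace('.', '_')
    if name.startswith('_'):
        raise ValueError('Malformed type in message (leading underscore)')
    return name


class Dispatcher:
    """Hypothetical consumer stand-in that records what it would send."""

    def __init__(self):
        self.sent = []

    async def send_json(self, content):
        self.sent.append(content)

    async def echo_message(self, message):
        # Same body as the tutorial's echo_message(): forward to the client.
        await self.send_json(message)

    async def dispatch(self, message):
        # Resolve the handler method from the message type, then call it.
        handler = getattr(self, handler_name(message))
        await handler(message)
```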
Run the tests now to confirm that they pass.
9 passed in 1.37s
Broadcasting a Request
A ride request should be broadcast to all drivers in the driver pool the moment it is sent. Let’s create a test to capture that behavior.
# tests/test_websocket.py
async def test_driver_alerted_on_request(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS
    # Listen to the 'drivers' group test channel.
    channel_layer = get_channel_layer()
    await channel_layer.group_add(
        group='drivers',
        channel='test_channel'
    )
    user, access = await create_user(
        'test.user@example.com', 'pAssw0rd', 'rider'
    )
    communicator = WebsocketCommunicator(
        application=application,
        path=f'/taxi/?token={access}'
    )
    connected, _ = await communicator.connect()
    # Request a trip.
    await communicator.send_json_to({
        'type': 'create.trip',
        'data': {
            'pick_up_address': '123 Main Street',
            'drop_off_address': '456 Piney Road',
            'rider': user.id,
        },
    })
    # Receive JSON message from server on test channel.
    response = await channel_layer.receive('test_channel')
    response_data = response.get('data')
    assert response_data['id'] is not None
    assert response_data['rider']['username'] == user.username
    assert response_data['driver'] is None
    await communicator.disconnect()
We start off by getting the channel layer and adding a channel named test_channel to the drivers group, so every message broadcast to the drivers group is also delivered to test_channel. Next, we establish a connection to the server as a rider, and we send a new request message over the wire. Finally, we wait for the broadcast message to arrive on test_channel, and we confirm the identity of the rider who sent it.
Let’s add the missing functionality to our consumer.
# trips/consumers.py
async def create_trip(self, message):
    data = message.get('data')
    trip = await self._create_trip(data)
    trip_data = NestedTripSerializer(trip).data
    # Send rider requests to all drivers.
    await self.channel_layer.group_send(group='drivers', message={
        'type': 'echo.message',
        'data': trip_data
    })
    await self.send_json({
        'type': 'echo.message',
        'data': trip_data,
    })
Run the tests again to confirm that they’re passing.
10 passed in 1.53s
Listening for an Update
We’re not done yet. We’ve handled creating a trip and broadcasting it to drivers, but we haven’t built a mechanism for receiving messages back from the drivers yet. Remember, when the rider sends a request, we create a Trip record and link him to it. We’re missing the piece that associates the correct communication channel with that rider.
We need to add two pieces of functionality:
1. Create a group for the new Trip record and add the rider to it.
2. Add the rider to all of the trip-related groups he belongs to when the WebSocket connects, and remove him from them when the WebSocket disconnects.
Let’s create a test for the first part:
# tests/test_websocket.py
async def test_create_trip_group(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS
    user, access = await create_user(
        'test.user@example.com', 'pAssw0rd', 'rider'
    )
    communicator = WebsocketCommunicator(
        application=application,
        path=f'/taxi/?token={access}'
    )
    connected, _ = await communicator.connect()
    # Send a ride request.
    await communicator.send_json_to({
        'type': 'create.trip',
        'data': {
            'pick_up_address': '123 Main Street',
            'drop_off_address': '456 Piney Road',
            'rider': user.id,
        },
    })
    response = await communicator.receive_json_from()
    response_data = response.get('data')
    # Send a message to the trip group.
    message = {
        'type': 'echo.message',
        'data': 'This is a test message.',
    }
    channel_layer = get_channel_layer()
    await channel_layer.group_send(response_data['id'], message=message)
    # Rider receives message.
    response = await communicator.receive_json_from()
    assert response == message
    await communicator.disconnect()
Here’s the update we need to make to the consumer.
# trips/consumers.py
async def create_trip(self, message):
    data = message.get('data')
    trip = await self._create_trip(data)
    trip_data = NestedTripSerializer(trip).data
    # Send rider requests to all drivers.
    await self.channel_layer.group_send(group='drivers', message={
        'type': 'echo.message',
        'data': trip_data
    })
    # Add rider to trip group.
    await self.channel_layer.group_add(  # new
        group=f'{trip.id}',
        channel=self.channel_name
    )
    await self.send_json({
        'type': 'echo.message',
        'data': trip_data,
    })
Now, let’s create a test for the second part. Make a new create_trip() function right beneath the create_user() function.
# tests/test_websocket.py
from trips.models import Trip
@database_sync_to_async
def create_trip(
    pick_up_address='123 Main Street',
    drop_off_address='456 Piney Road',
    status='REQUESTED',
    rider=None,
    driver=None
):
    return Trip.objects.create(
        pick_up_address=pick_up_address,
        drop_off_address=drop_off_address,
        status=status,
        rider=rider,
        driver=driver
    )
Next, add the test.
# tests/test_websocket.py
async def test_join_trip_group_on_connect(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS
    user, access = await create_user(
        'test.user@example.com', 'pAssw0rd', 'rider'
    )
    trip = await create_trip(rider=user)
    communicator = WebsocketCommunicator(
        application=application,
        path=f'/taxi/?token={access}'
    )
    connected, _ = await communicator.connect()
    # Send a message to the trip group.
    message = {
        'type': 'echo.message',
        'data': 'This is a test message.',
    }
    channel_layer = get_channel_layer()
    await channel_layer.group_send(f'{trip.id}', message=message)
    # Rider receives message.
    response = await communicator.receive_json_from()
    assert response == message
    await communicator.disconnect()
This time we need to update the consumer’s connect() and disconnect() methods:
# trips/consumers.py
class TaxiConsumer(AsyncJsonWebsocketConsumer):
    async def connect(self):
        user = self.scope['user']
        if user.is_anonymous:
            await self.close()
        else:
            user_group = await self._get_user_group(user)
            if user_group == 'driver':
                await self.channel_layer.group_add(
                    group='drivers',
                    channel=self.channel_name
                )
            # new
            for trip_id in await self._get_trip_ids(user):
                await self.channel_layer.group_add(
                    group=trip_id,
                    channel=self.channel_name
                )
            await self.accept()
    # Other methods hidden for clarity.
    async def disconnect(self, code):
        user = self.scope['user']
        # Anonymous connections were rejected in connect() and joined no groups
        # (an AnonymousUser also has no trips_as_rider/trips_as_driver).
        if not user.is_anonymous:
            user_group = await self._get_user_group(user)
            if user_group == 'driver':
                await self.channel_layer.group_discard(
                    group='drivers',
                    channel=self.channel_name
                )
            # new
            for trip_id in await self._get_trip_ids(user):
                await self.channel_layer.group_discard(
                    group=trip_id,
                    channel=self.channel_name
                )
        await super().disconnect(code)
Let’s add the _get_trip_ids() helper method to the consumer too:
# trips/consumers.py
from trips.models import Trip
@database_sync_to_async
def _get_trip_ids(self, user):
    user_groups = user.groups.values_list('name', flat=True)
    if 'driver' in user_groups:
        trip_ids = user.trips_as_driver.exclude(
            status=Trip.COMPLETED
        ).only('id').values_list('id', flat=True)
    else:
        trip_ids = user.trips_as_rider.exclude(
            status=Trip.COMPLETED
        ).only('id').values_list('id', flat=True)
    # Evaluate the queryset here, inside the sync worker thread, and convert
    # the UUIDs to strings because group names are strings.
    return [str(trip_id) for trip_id in trip_ids]
You should notice that our helper function selects Trip records according to whether the user is a rider or a driver, and skips completed trips. Regardless of which group the user belongs to, he will be added to and removed from the correct communication channels.
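Because a trip's group is named after its primary key, the consumer needs the trip IDs as strings. The plain-Python sketch below mirrors the selection rule in _get_trip_ids(): drivers follow the trips they drive, everyone else follows the trips they ride, and completed trips are skipped. The Trip dataclass and trip_group_names() are hypothetical stand-ins for the model and the query, with usernames in place of foreign keys.

```python
from dataclasses import dataclass
from typing import Optional

COMPLETED = 'COMPLETED'


@dataclass
class Trip:
    id: str
    status: str
    rider: Optional[str] = None   # username standing in for the rider FK
    driver: Optional[str] = None  # username standing in for the driver FK


def trip_group_names(trips, username, user_groups):
    # Drivers follow trips they drive; everyone else follows trips they ride.
    role_field = 'driver' if 'driver' in user_groups else 'rider'
    return [
        str(trip.id)
        for trip in trips
        if getattr(trip, role_field) == username and trip.status != COMPLETED
    ]
```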
With these changes in place, all tests should pass:
12 passed in 1.80s
Review
Let’s review what we’ve learned so far. The WebSocket protocol works differently than HTTP. Whereas an HTTP exchange lasts only as long as the request/response cycle, a WebSocket connection stays open until one of the two parties closes it. Both the client and the server can send each other messages independently over the open connection.
With Django Channels, when the client sends a message to the server, the server can respond over the same connection. The server can also send messages to other connected clients through channel layers and groups: when a message is broadcast to a group, it is delivered to every channel in that group.
So far, we have two categories of users: drivers and riders. Drivers join the drivers group (the driver pool) as soon as they connect. When a rider requests a trip, the server creates both a Trip record in the database and a corresponding communication group identified by that record’s primary key. The rider is linked to the database record and is added to the group. The act of requesting a trip prompts the server to alert every driver in the driver pool.
In the next chapter, we’ll wrap up the server functionality.
WebSockets - Part Three
Part 1, Chapter 8
Accepting a Request
A rider has sent a request to the server. The server has broadcast the request to everyone in the driver pool. Now what? A driver needs to accept the request and start driving to the pick-up address.
Updating a Trip
When a driver accepts a trip request, we need to tell the rider who requested the trip that the request has been filled.
Let’s write a test to show that the rider is updated when a driver accepts the request:
# tests/test_websocket.py
async def test_driver_can_update_trip(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS
    # Create trip request.
    rider, _ = await create_user(
        'test.rider@example.com', 'pAssw0rd', 'rider'
    )
    trip = await create_trip(rider=rider)
    trip_id = f'{trip.id}'
    # Listen for messages as rider.
    channel_layer = get_channel_layer()
    await channel_layer.group_add(
        group=trip_id,
        channel='test_channel'
    )
    # Update trip.
    driver, access = await create_user(
        'test.driver@example.com', 'pAssw0rd', 'driver'
    )
    communicator = WebsocketCommunicator(
        application=application,
        path=f'/taxi/?token={access}'
    )
    connected, _ = await communicator.connect()
    message = {
        'type': 'update.trip',
        'data': {
            'id': trip_id,
            'pick_up_address': trip.pick_up_address,
            'drop_off_address': trip.drop_off_address,
            'status': Trip.IN_PROGRESS,
            'driver': driver.id,
        },
    }
    await communicator.send_json_to(message)
    # Rider receives message.
    response = await channel_layer.receive('test_channel')
    response_data = response.get('data')
    assert response_data['id'] == trip_id
    assert response_data['rider']['username'] == rider.username
    assert response_data['driver']['username'] == driver.username
    await communicator.disconnect()
In the test above, we create a rider and a trip and then we start listening on the communication channel associated with the trip. Next, we create a driver and send a message to the server to update the trip. Lastly, we confirm that the message gets broadcast to the rider.
Let’s add the new functionality to the consumer:
# trips/consumers.py
async def receive_json(self, content, **kwargs):
    message_type = content.get('type')
    if message_type == 'create.trip':
        await self.create_trip(content)
    elif message_type == 'echo.message':
        await self.echo_message(content)
    elif message_type == 'update.trip':  # new
        await self.update_trip(content)
# new
async def update_trip(self, message):
    data = message.get('data')
    trip = await self._update_trip(data)
    trip_id = f'{trip.id}'
    trip_data = NestedTripSerializer(trip).data
    # Send update to rider.
    await self.channel_layer.group_send(
        group=trip_id,
        message={
            'type': 'echo.message',
            'data': trip_data,
        }
    )
    # Add driver to the trip group.
    await self.channel_layer.group_add(
        group=trip_id,
        channel=self.channel_name
    )
    await self.send_json({
        'type': 'echo.message',
        'data': trip_data
    })
# new
@database_sync_to_async
def _update_trip(self, data):
    instance = Trip.objects.get(id=data.get('id'))
    serializer = TripSerializer(data=data)
    serializer.is_valid(raise_exception=True)
    return serializer.update(instance, serializer.validated_data)
The update trip functionality is almost the mirror image of the create trip functionality. A driver sends a message to update a trip. The relevant Trip record gets updated to include the driver. The server broadcasts a message with the updated trip information to everyone in the trip group, and then adds the driver to the trip group.
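The order of the calls in update_trip() matters. The driver is added to the trip group only after the broadcast, so the driver receives the update once (through send_json()) instead of twice. The framework-free simulation below counts the messages each channel receives under both orderings. simulate_update() is a hypothetical helper, and plain lists stand in for the channel layer.

```python
from collections import defaultdict


def simulate_update(join_before_broadcast=False):
    """Count messages per channel when a driver accepts a trip (hypothetical)."""
    inboxes = defaultdict(list)  # channel name -> delivered messages
    groups = defaultdict(set)    # group name -> member channel names
    trip_id = 'trip-1'
    groups[trip_id].add('rider')  # the rider joined when the trip was created
    update = {'type': 'echo.message', 'data': {'id': trip_id}}

    def group_send(group, message):
        for channel in groups[group]:
            inboxes[channel].append(message)

    if join_before_broadcast:
        groups[trip_id].add('driver')
    group_send(trip_id, update)        # broadcast to current members only
    if not join_before_broadcast:
        groups[trip_id].add('driver')  # the order update_trip() uses
    inboxes['driver'].append(update)   # stands in for self.send_json()
    return {channel: len(messages) for channel, messages in inboxes.items()}
```

With the tutorial's ordering, simulate_update() reports one message each for the rider and the driver. With the driver joining first, the driver gets the same update twice.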
Run pytest and confirm that the new test passes:
13 passed in 1.93s
Joining a Trip
In the previous chapter, we wrote a test to prove that riders were added to the appropriate communication channels when they reconnected to the server. Let’s write a similar test for drivers. Remember, we’re already testing the functionality that adds drivers to the driver pool on login. This test confirms that once a driver has joined a trip (by accepting a trip request), he is added back to the trip’s communication channel when he reconnects.
Add the following test:
# tests/test_websocket.py
async def test_driver_join_trip_group_on_connect(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS
    user, access = await create_user(
        'test.user@example.com', 'pAssw0rd', 'driver'
    )
    trip = await create_trip(driver=user)
    communicator = WebsocketCommunicator(
        application=application,
        path=f'/taxi/?token={access}'
    )
    connected, _ = await communicator.connect()
    # Send a message to the trip group.
    message = {
        'type': 'echo.message',
        'data': 'This is a test message.',
    }
    channel_layer = get_channel_layer()
    await channel_layer.group_send(f'{trip.id}', message=message)
    # Driver receives message.
    response = await communicator.receive_json_from()
    assert response == message
    await communicator.disconnect()
There’s no new functionality to add to the consumer: _get_trip_ids() already looks up a driver’s active trips through trips_as_driver.
One last test run:
14 passed in 2.06s
Independent Study
We’ve finished coding and testing some basic user workflows, but there are many other feature enhancements that we could add. Using what you learned, can you figure out how to program these common scenarios?
- The rider cancels his request after a driver accepts it.
- The server alerts all other drivers in the driver pool that someone has accepted a request.
- The driver periodically broadcasts his location to the rider during a trip.
- The server only allows a rider to request one trip at a time.
- The rider can share his trip with another rider, who can join the trip and receive updates.
- The server only shares a trip request to drivers in a specific geographic location.
- If no drivers accept the request within a certain timespan, the server cancels the request and returns a message to the rider.
UI Support
Part 1, Chapter 9
Up until now, we haven’t had a reason to track users as drivers or riders. Users can be either. But as soon as we add a UI, we will need a way for users to sign up with a role. Drivers will see a different UI and will experience different functionality than riders.
The first thing we need to do is add support for user groups in our serializers in trips/serializers.py.
Serializer
First, update UserSerializer:
# trips/serializers.py
class UserSerializer(serializers.ModelSerializer):
    password1 = serializers.CharField(write_only=True)
    password2 = serializers.CharField(write_only=True)
    group = serializers.CharField()  # new
    def validate(self, data):
        if data['password1'] != data['password2']:
            raise serializers.ValidationError('Passwords must match.')
        return data
    def create(self, validated_data):  # changed
        group_data = validated_data.pop('group')
        group, _ = Group.objects.get_or_create(name=group_data)
        data = {
            key: value for key, value in validated_data.items()
            if key not in ('password1', 'password2')
        }
        data['password'] = validated_data['password1']
        user = self.Meta.model.objects.create_user(**data)
        user.groups.add(group)
        user.save()
        return user
    class Meta:
        model = get_user_model()
        fields = (
            'id', 'username', 'password1', 'password2',
            'first_name', 'last_name', 'group',  # new
        )
        read_only_fields = ('id',)
Add the import at the top, too.
from django.contrib.auth.models import Group
Model
Next, update the custom user model in trips/models.py to support groups as well:
# trips/models.py
class User(AbstractUser):
    @property
    def group(self):
        groups = self.groups.all()
        return groups[0].name if groups else None
We are not adding any database fields, so we don’t need to create a new migration.
View
Lastly, add the proper filters to the TripView in trips/views.py:
# trips/views.py
class TripView(viewsets.ReadOnlyModelViewSet):
    lookup_field = 'id'
    lookup_url_kwarg = 'trip_id'
    permission_classes = (permissions.IsAuthenticated,)
    serializer_class = TripSerializer
    def get_queryset(self):  # new
        user = self.request.user
        if user.group == 'driver':
            return Trip.objects.filter(
                Q(status=Trip.REQUESTED) | Q(driver=user)
            )
        if user.group == 'rider':
            return Trip.objects.filter(rider=user)
        return Trip.objects.none()
Note that we removed the queryset = Trip.objects.all() attribute; get_queryset() now determines which trips each user can see.
Add the import at the top:
from django.db.models import Q
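Writing the same rules over an in-memory list makes it easier to see what each role gets back from get_queryset(). Drivers see open requests plus their own trips, riders see only the trips they requested, and users without a group see nothing. The Trip dataclass and visible_trips() are hypothetical stand-ins, not the DRF view, and usernames replace the foreign keys.

```python
from dataclasses import dataclass
from typing import List, Optional

REQUESTED = 'REQUESTED'


@dataclass
class Trip:
    id: str
    status: str
    rider: Optional[str] = None   # username standing in for the rider FK
    driver: Optional[str] = None  # username standing in for the driver FK


def visible_trips(trips: List[Trip], user: str, group: Optional[str]) -> List[Trip]:
    if group == 'driver':
        # Q(status=Trip.REQUESTED) | Q(driver=user)
        return [t for t in trips if t.status == REQUESTED or t.driver == user]
    if group == 'rider':
        # Trip.objects.filter(rider=user)
        return [t for t in trips if t.rider == user]
    # Trip.objects.none()
    return []
```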
Test
Re-run the tests:
(env)$ python manage.py test trips.tests
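You should see three failures. The sign-up test fails because group is now a required field. The trip tests fail because the test user doesn’t belong to a group yet, so get_queryset() returns no trips for that user. The errors look like this: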
AssertionError: 201 != 400
...
AssertionError: Element counts were not equal:
First has 1, Second has 0: '8c1b7e58-6e39-46c2-ac83-b934b6a8c172'
First has 1, Second has 0: '258f55ed-72c1-488b-bd2d-4ff42cbb8bfd'
...
AssertionError: 200 != 404
First, within trips/tests/test_http.py, update the create_user() function to take an additional group_name parameter:
# tests/test_http.py
def create_user(username='user@example.com', password=PASSWORD, group_name='rider'):
    group, _ = Group.objects.get_or_create(name=group_name)
    user = get_user_model().objects.create_user(
        username=username, password=password)
    user.groups.add(group)
    user.save()
    return user
And add the group to the test_user_can_sign_up test in AuthenticationTest:
# tests/test_http.py
def test_user_can_sign_up(self):
    response = self.client.post(reverse('sign_up'), data={
        'username': 'user@example.com',
        'first_name': 'Test',
        'last_name': 'User',
        'password1': PASSWORD,
        'password2': PASSWORD,
        'group': 'rider',  # new
    })
    user = get_user_model().objects.last()
    self.assertEqual(status.HTTP_201_CREATED, response.status_code)
    self.assertEqual(response.data['id'], user.id)
    self.assertEqual(response.data['username'], user.username)
    self.assertEqual(response.data['first_name'], user.first_name)
    self.assertEqual(response.data['last_name'], user.last_name)
    self.assertEqual(response.data['group'], user.group)  # new
Add the import:
from django.contrib.auth.models import Group
Run the tests again. You should see only two failures, with the following errors:
...
AssertionError: Element counts were not equal:
First has 1, Second has 0: 'd09af51e-040e-4161-954b-2b821dd49b31'
First has 1, Second has 0: 'f67a3d2a-5aa2-4f9e-a467-5c11da7dd617'
...
AssertionError: 200 != 404
Next, update HttpTripTest:
# tests/test_http.py
class HttpTripTest(APITestCase):
    def setUp(self):
        self.user = create_user()  # changed
        self.client.login(username=self.user.username, password=PASSWORD)  # changed
    def test_user_can_list_trips(self):  # changed
        trips = [
            Trip.objects.create(
                pick_up_address='A', drop_off_address='B', rider=self.user),
            Trip.objects.create(
                pick_up_address='B', drop_off_address='C', rider=self.user),
            Trip.objects.create(
                pick_up_address='C', drop_off_address='D')
        ]
        response = self.client.get(reverse('trip:trip_list'))
        self.assertEqual(status.HTTP_200_OK, response.status_code)
        exp_trip_ids = [str(trip.id) for trip in trips[0:2]]
        act_trip_ids = [trip.get('id') for trip in response.data]
        self.assertCountEqual(act_trip_ids, exp_trip_ids)
    def test_user_can_retrieve_trip_by_id(self):  # changed
        trip = Trip.objects.create(
            pick_up_address='A', drop_off_address='B', rider=self.user)
        response = self.client.get(trip.get_absolute_url())
        self.assertEqual(status.HTTP_200_OK, response.status_code)
        self.assertEqual(str(trip.id), response.data.get('id'))
After these changes, all tests should pass:
(env)$ python manage.py test trips.tests
Creating test database for alias 'default'...
System check identified no issues (0 silenced).
.....
----------------------------------------------------------------------
Ran 4 tests in 1.261s
OK
Destroying test database for alias 'default'...
Finally, run the server and test out the DRF Browsable API:
http://localhost:8000/api/sign_up/
http://localhost:8000/api/log_in/
User Photos
Part 1, Chapter 10
Viewing a user’s photo is an important piece of functionality in ride-sharing apps. In fact, most of these apps make it mandatory to provide a photo before you can drive or ride. From one perspective, it’s a security issue: riders need to confirm that drivers are who they expect before getting into their vehicles. User photos are also good design and add life to the product.
Our app will allow users to add their photos at sign up.
Media Files
Media files are a form of user-generated static files, and Django handles both in a similar way. We need to provide two new settings, MEDIA_ROOT and MEDIA_URL.
MEDIA_ROOT is the path to the directory where file uploads will be saved. For the purpose of this tutorial, we can create a “media” folder inside our “server” directory. In a production environment, we’d specify an absolute path to a directory on the server or we’d store files with a service like AWS S3. MEDIA_URL is the URL prefix under which those uploaded files are served.
Set both the MEDIA_URL and the MEDIA_ROOT within the settings file:
# taxi/settings.py
MEDIA_URL = '/media/'
MEDIA_ROOT = os.path.join(BASE_DIR, '../media')
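It’s worth checking where '../media' actually points. Assume BASE_DIR is the directory that contains manage.py, which is the default in a Django-generated settings file. Then the '..' climbs one level above it, and the folder ends up next to “server” rather than inside it. The sketch below resolves both variants for a hypothetical BASE_DIR of /srv/taxi-app/server, using POSIX paths. It also shows how a stored file name maps onto MEDIA_URL. If your BASE_DIR follows that default and you want the folder inside “server”, os.path.join(BASE_DIR, 'media') is the variant to use.

```python
import posixpath
from urllib.parse import urljoin

# Hypothetical BASE_DIR: the folder that contains manage.py (Django's default).
BASE_DIR = '/srv/taxi-app/server'
MEDIA_URL = '/media/'

# '../media' climbs out of BASE_DIR, so it names a sibling of "server".
media_root_parent = posixpath.normpath(posixpath.join(BASE_DIR, '../media'))
# Joining 'media' without '..' keeps the folder inside "server".
media_root_inside = posixpath.normpath(posixpath.join(BASE_DIR, 'media'))

# An upload stored as photos/photo.png is served under MEDIA_URL.
photo_url = urljoin(MEDIA_URL, 'photos/photo.png')
```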
One last step is required to get our local environment to serve media files. Update taxi/urls.py like so:
# taxi/urls.py
from django.conf import settings  # new
from django.conf.urls.static import static  # new
from django.contrib import admin
from django.urls import include, path
from rest_framework_simplejwt.views import TokenRefreshView
from trips.views import SignUpView, LogInView
urlpatterns = [
    path('admin/', admin.site.urls),
    path('api/sign_up/', SignUpView.as_view(), name='sign_up'),
    path('api/log_in/', LogInView.as_view(), name='log_in'),
    path('api/token/refresh/', TokenRefreshView.as_view(),
         name='token_refresh'),
    path('api/trip/', include('trips.urls', 'trip',)),
] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)  # new
Media files can now be retrieved via http://localhost:8000/media/<file_path> on your local machine.
To test, add a new folder called “media” to “server”. Then, add a test file called test.txt to that folder and add some random text to the file. Fire up the server and navigate to http://localhost:8000/media/test.txt to view the file.
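Make sure you remove the static() call from urlpatterns when you deploy your application; we only need it for local development. Django’s static() helper only adds URL patterns when DEBUG is True, so in production you’ll need to serve media files another way, such as from your web server or a storage service like AWS S3.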
Tests
Change the existing AuthenticationTest.test_user_can_sign_up test in the following way:
# tests/test_http.py
def test_user_can_sign_up(self):
    photo_file = create_photo_file()  # new
    response = self.client.post(reverse('sign_up'), data={
        'username': 'user@example.com',
        'first_name': 'Test',
        'last_name': 'User',
        'password1': PASSWORD,
        'password2': PASSWORD,
        'group': 'rider',
        'photo': photo_file,  # new
    })
    user = get_user_model().objects.last()
    self.assertEqual(status.HTTP_201_CREATED, response.status_code)
    self.assertEqual(response.data['id'], user.id)
    self.assertEqual(response.data['username'], user.username)
    self.assertEqual(response.data['first_name'], user.first_name)
    self.assertEqual(response.data['last_name'], user.last_name)
    self.assertEqual(response.data['group'], user.group)
    self.assertIsNotNone(user.photo)  # new
Add the create_photo_file helper function right after the create_user helper:
# trips/tests/test_http.py
def create_photo_file():
    data = BytesIO()
    Image.new('RGB', (100, 100)).save(data, 'PNG')
    data.seek(0)
    return SimpleUploadedFile('photo.png', data.getvalue())
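This code uses the Pillow library to draw a blank 100x100 RGB image, saves it as PNG data into an in-memory BytesIO buffer from the standard library, and wraps those bytes in Django’s SimpleUploadedFile so the test client can post them as a file upload.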
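One detail in create_photo_file() is easy to misread: a BytesIO buffer keeps a read/write cursor. The standard-library sketch below (cursor_demo() is a hypothetical name) shows that read() returns nothing right after a write, while getvalue() returns the whole buffer regardless of the cursor. So data.seek(0) is harmless but not strictly required when the helper calls getvalue(); it matters when you hand the buffer itself to code that reads from it.

```python
from io import BytesIO


def cursor_demo(payload: bytes):
    """Show how the BytesIO cursor affects read() but not getvalue()."""
    buffer = BytesIO()
    buffer.write(payload)
    # After writing, the cursor sits at the end, so read() finds nothing.
    after_write = buffer.read()
    # getvalue() ignores the cursor and returns everything written so far.
    whole_buffer = buffer.getvalue()
    # Rewinding lets read() (and any file-like consumer) start from byte 0.
    buffer.seek(0)
    after_seek = buffer.read()
    return after_write, whole_buffer, after_seek


print(cursor_demo(b'fake image bytes'))
# (b'', b'fake image bytes', b'fake image bytes')
```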
Add the imports:
# trips/tests/test_http.py
from io import BytesIO
from PIL import Image
from django.core.files.uploadedfile import SimpleUploadedFile
Of course, the test fails because we haven't updated our user model and its serializer yet:
AttributeError: 'User' object has no attribute 'photo'
File Changes
Modify the user model:
# trips/models.py
class User(AbstractUser):
    photo = models.ImageField(upload_to='photos', null=True, blank=True)  # new

    @property
    def group(self):
        groups = self.groups.all()
        return groups[0].name if groups else None
Now, when users upload their photos, the app will save them in a photos subdirectory within our media folder.
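To make the storage layout concrete, here is a simplified, standard-library-only model of where an uploaded photo ends up. It assumes the default file system storage, a MEDIA_URL of /media/, and an example MEDIA_ROOT path; photo_locations() is a hypothetical helper. Django's real logic also passes upload_to through strftime() and may rename a file to avoid overwriting an existing one, which this sketch ignores.

```python
from pathlib import PurePosixPath


def photo_locations(media_root, media_url, upload_to, filename):
    """Return (stored name, path on disk, public URL) for an uploaded file."""
    # The database column stores the name relative to MEDIA_ROOT.
    name = str(PurePosixPath(upload_to) / filename)
    # The file itself is written under MEDIA_ROOT.
    disk_path = str(PurePosixPath(media_root) / name)
    # The public URL is MEDIA_URL followed by the stored name.
    url = media_url + name
    return name, disk_path, url


print(photo_locations('/home/me/server/media', '/media/', 'photos', 'photo.png'))
# ('photos/photo.png', '/home/me/server/media/photos/photo.png', '/media/photos/photo.png')
```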
Edit the serializers.py file to include the photo field in UserSerializer.
# trips/serializers.py
class UserSerializer(serializers.ModelSerializer):
    password1 = serializers.CharField(write_only=True)
    password2 = serializers.CharField(write_only=True)
    group = serializers.CharField()

    def validate(self, data):
        if data['password1'] != data['password2']:
            raise serializers.ValidationError('Passwords must match.')
        return data

    def create(self, validated_data):
        group_data = validated_data.pop('group')
        group, _ = Group.objects.get_or_create(name=group_data)
        data = {
            key: value for key, value in validated_data.items()
            if key not in ('password1', 'password2')
        }
        data['password'] = validated_data['password1']
        user = self.Meta.model.objects.create_user(**data)
        user.groups.add(group)
        user.save()
        return user

    class Meta:
        model = get_user_model()
        fields = (
            'id', 'username', 'password1', 'password2',
            'first_name', 'last_name', 'group',
            'photo',  # new
        )
        read_only_fields = ('id',)
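The photo field needs no special handling in create(): it passes through the dictionary comprehension straight into create_user(), which stores it on the new User. The plain-Python mirror of validate() and create() below makes that data flow visible without a database. split_sign_up_data() is a hypothetical helper; the real serializer still does the group lookup and password hashing through Django.

```python
def split_sign_up_data(validated_data):
    """Mirror UserSerializer.validate() and create() on a plain dict."""
    if validated_data['password1'] != validated_data['password2']:
        raise ValueError('Passwords must match.')
    data = dict(validated_data)  # copy so the caller's dict is untouched
    group_name = data.pop('group')
    # Drop the confirmation fields; every other key (including 'photo')
    # becomes a keyword argument for create_user().
    user_kwargs = {
        key: value for key, value in data.items()
        if key not in ('password1', 'password2')
    }
    user_kwargs['password'] = data['password1']
    return group_name, user_kwargs
```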
Edit the views.py file to make TripView use the NestedTripSerializer. We want the Trip API response payload to include full driver and rider object representations.
# trips/views.py
from django.contrib.auth import get_user_model
from django.db.models import Q
from rest_framework import generics, permissions, viewsets
from rest_framework_simplejwt.views import TokenObtainPairView

from .models import Trip
from .serializers import LogInSerializer, NestedTripSerializer, UserSerializer  # changed


class SignUpView(generics.CreateAPIView):
    queryset = get_user_model().objects.all()
    serializer_class = UserSerializer


class LogInView(TokenObtainPairView):
    serializer_class = LogInSerializer


class TripView(viewsets.ReadOnlyModelViewSet):
    lookup_field = 'id'
    lookup_url_kwarg = 'trip_id'
    permission_classes = (permissions.IsAuthenticated,)
    serializer_class = NestedTripSerializer  # changed

    def get_queryset(self):
        user = self.request.user
        if user.group == 'driver':
            return Trip.objects.filter(
                Q(status=Trip.REQUESTED) | Q(driver=user)
            )
        if user.group == 'rider':
            return Trip.objects.filter(rider=user)
        return Trip.objects.none()
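The Q objects in get_queryset() combine two conditions with OR. The in-memory sketch below applies the same visibility rules to plain Python objects, which can help when reasoning about who sees which trips. FakeTrip, visible_trips(), and the 'REQUESTED' status string are stand-ins assumed for illustration; in the app, Trip.REQUESTED holds the real status value and the filtering happens in the database.

```python
from dataclasses import dataclass
from typing import List, Optional

REQUESTED = 'REQUESTED'  # assumed stand-in for Trip.REQUESTED


@dataclass(frozen=True)
class FakeTrip:
    id: int
    status: str
    rider: str
    driver: Optional[str] = None


def visible_trips(trips: List[FakeTrip], username: str, group: Optional[str]):
    """Apply TripView.get_queryset()'s rules to an in-memory list."""
    if group == 'driver':
        # Q(status=Trip.REQUESTED) | Q(driver=user)
        return [t for t in trips if t.status == REQUESTED or t.driver == username]
    if group == 'rider':
        # Trip.objects.filter(rider=user)
        return [t for t in trips if t.rider == username]
    # Trip.objects.none()
    return []
```

A driver sees every open request plus the trips assigned to them, while a rider sees only their own trips, whatever their status.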
One last thing: create a migration for the new photo field on our user table and run the migrations.
(env)$ python manage.py makemigrations
(env)$ python manage.py migrate
Now the tests should pass.
(env)$ python manage.py test trips.tests
Creating test database for alias 'default'...
System check identified no issues (0 silenced).
.....
----------------------------------------------------------------------
Ran 5 tests in 0.972s
OK
Destroying test database for alias 'default'...
Here’s our directory structure in its current state.
.
├── pytest.ini
└── server
    ├── media
    │   └── test.txt
    └── taxi
        ├── db.sqlite3
        ├── manage.py
        ├── taxi
        │   ├── __init__.py
        │   ├── asgi.py
        │   ├── routing.py
        │   ├── settings.py
        │   ├── urls.py
        │   └── wsgi.py
        └── trips
            ├── __init__.py
            ├── admin.py
            ├── apps.py
            ├── consumers.py
            ├── migrations
            │   ├── 0001_initial.py
            │   ├── 0002_trip.py
            │   ├── 0003_trip_driver_rider.py
            │   ├── 0004_user_photo.py
            │   └── __init__.py
            ├── models.py
            ├── serializers.py
            ├── tests
            │   ├── __init__.py
            │   ├── test_http.py
            │   └── test_websockets.py
            ├── urls.py
            └── views.py
Conclusion
Part 1, Chapter 11
Part 1 of this tutorial covered a lot of material. We implemented JSON Web Token authentication on the server side, queried the server for data with RESTful APIs, and used WebSocket messages (via Django Channels) to create data and send alerts to users. We also used test-driven development to plan and test our features. Parts 2 and 3 will delve into the UI programming for our app.