diff --git a/backend/server/worldtravel/management/commands/download-countries.py b/backend/server/worldtravel/management/commands/download-countries.py index 411f55b..aff0f2e 100644 --- a/backend/server/worldtravel/management/commands/download-countries.py +++ b/backend/server/worldtravel/management/commands/download-countries.py @@ -51,65 +51,94 @@ class Command(BaseCommand): with open(countries_json_path, 'r') as f: data = json.load(f) - countries_to_create = [] - regions_to_create = [] - - for country in data: - country_code = country['iso2'] - country_name = country['name'] - country_subregion = country['subregion'] - - country_obj = Country( - name=country_name, - country_code=country_code, - subregion=country_subregion - ) - countries_to_create.append(country_obj) - - saveCountryFlag(country_code) - self.stdout.write(self.style.SUCCESS(f'Country {country_name} prepared')) - - # Bulk create countries first with transaction.atomic(): - Country.objects.bulk_create(countries_to_create, ignore_conflicts=True) + existing_countries = {country.country_code: country for country in Country.objects.all()} + existing_regions = {region.id: region for region in Region.objects.all()} - # Fetch all countries to get their database IDs - countries = {country.country_code: country for country in Country.objects.all()} + countries_to_create = [] + regions_to_create = [] + countries_to_update = [] + regions_to_update = [] - for country in data: - country_code = country['iso2'] - country_obj = countries[country_code] - - if country['states']: - for state in country['states']: - name = state['name'] - state_id = f"{country_code}-{state['state_code']}" - latitude = round(float(state['latitude']), 6) if state['latitude'] else None - longitude = round(float(state['longitude']), 6) if state['longitude'] else None + processed_country_codes = set() + processed_region_ids = set() - region_obj = Region( - id=state_id, - name=name, - country=country_obj, - longitude=longitude, - latitude=latitude + for country in data: + country_code = country['iso2'] + country_name = country['name'] + country_subregion = country['subregion'] + + processed_country_codes.add(country_code) + + if country_code in existing_countries: + country_obj = existing_countries[country_code] + country_obj.name = country_name + country_obj.subregion = country_subregion + countries_to_update.append(country_obj) + else: + country_obj = Country( + name=country_name, + country_code=country_code, + subregion=country_subregion ) - regions_to_create.append(region_obj) - self.stdout.write(self.style.SUCCESS(f'State {state_id} prepared')) - else: - # Create one region with the name of the country if there are no states - region_obj = Region( - id=f"{country_code}-00", - name=country['name'], - country=country_obj - ) - regions_to_create.append(region_obj) - self.stdout.write(self.style.SUCCESS(f'Region {country_code}-00 prepared for {country["name"]}')) + countries_to_create.append(country_obj) - # Bulk create regions - with transaction.atomic(): - Region.objects.bulk_create(regions_to_create, ignore_conflicts=True) + saveCountryFlag(country_code) + self.stdout.write(self.style.SUCCESS(f'Country {country_name} prepared')) - self.stdout.write(self.style.SUCCESS('All data imported successfully')) + if country['states']: + for state in country['states']: + name = state['name'] + state_id = f"{country_code}-{state['state_code']}" + latitude = round(float(state['latitude']), 6) if state['latitude'] else None + longitude = round(float(state['longitude']), 6) if state['longitude'] else None - + processed_region_ids.add(state_id) + + if state_id in existing_regions: + region_obj = existing_regions[state_id] + region_obj.name = name + region_obj.country = country_obj + region_obj.longitude = longitude + region_obj.latitude = latitude + regions_to_update.append(region_obj) + else: + region_obj = Region( + id=state_id, + name=name, + country=country_obj, + longitude=longitude, + latitude=latitude + ) + regions_to_create.append(region_obj) + self.stdout.write(self.style.SUCCESS(f'State {state_id} prepared')) + else: + state_id = f"{country_code}-00" + processed_region_ids.add(state_id) + if state_id in existing_regions: + region_obj = existing_regions[state_id] + region_obj.name = country_name + region_obj.country = country_obj + regions_to_update.append(region_obj) + else: + region_obj = Region( + id=state_id, + name=country_name, + country=country_obj + ) + regions_to_create.append(region_obj) + self.stdout.write(self.style.SUCCESS(f'Region {state_id} prepared for {country_name}')) + + # Bulk create new countries and regions + Country.objects.bulk_create(countries_to_create) + Region.objects.bulk_create(regions_to_create) + + # Bulk update existing countries and regions + Country.objects.bulk_update(countries_to_update, ['name', 'subregion']) + Region.objects.bulk_update(regions_to_update, ['name', 'country', 'longitude', 'latitude']) + + # Delete countries and regions that are no longer in the data + Country.objects.exclude(country_code__in=processed_country_codes).delete() + Region.objects.exclude(id__in=processed_region_ids).delete() + + self.stdout.write(self.style.SUCCESS('All data imported successfully')) \ No newline at end of file diff --git a/backend/server/worldtravel/models.py b/backend/server/worldtravel/models.py index d54fe95..66e10d4 100644 --- a/backend/server/worldtravel/models.py +++ b/backend/server/worldtravel/models.py @@ -44,4 +44,4 @@ class VisitedRegion(models.Model): def save(self, *args, **kwargs): if VisitedRegion.objects.filter(user_id=self.user_id, region=self.region).exists(): raise ValidationError("Region already visited by user.") - super().save(*args, **kwargs) + super().save(*args, **kwargs) \ No newline at end of file