File size: 4,227 Bytes
ddda2fe
 
 
 
 
 
 
50974f8
ddda2fe
50974f8
ddda2fe
50974f8
 
 
ddda2fe
50974f8
 
ddda2fe
 
50974f8
ddda2fe
 
50974f8
 
 
 
 
ddda2fe
 
 
 
 
 
50974f8
ddda2fe
50974f8
ddda2fe
50974f8
 
ddda2fe
50974f8
 
ddda2fe
 
50974f8
 
 
ddda2fe
 
 
 
50974f8
ddda2fe
50974f8
 
ddda2fe
6857b34
50974f8
3c21054
 
ddda2fe
 
50974f8
 
ddda2fe
 
50974f8
ddda2fe
50974f8
ddda2fe
50974f8
 
 
ddda2fe
 
50974f8
ddda2fe
50974f8
 
ddda2fe
50974f8
 
ddda2fe
50974f8
 
ddda2fe
 
 
 
 
 
50974f8
ddda2fe
50974f8
 
 
fcb3586
 
 
 
50974f8
0747c82
 
 
50974f8
 
 
 
ddda2fe
 
 
50974f8
 
ddda2fe
 
 
 
50974f8
ddda2fe
50974f8
ddda2fe
 
 
50974f8
 
 
 
 
ddda2fe
 
 
 
 
50974f8
ddda2fe
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import sys
import csv
import random
from typing import List


def make_unique_lines(f_in: str, f_out: str) -> int:
    """
    Delete all duplicate lines of a file.

    Lines are compared without their trailing newline, so a final line that
    has no line terminator is still recognised as a duplicate of an earlier
    identical line. The first occurrence of each line is kept, in input
    order, and every line written to the output is newline-terminated.

    Args:
        f_in (str): File path to analyse, must contain extension.
        f_out (str): File path containing result, must contain extension.

    Returns:
        (int): The number of duplicate lines found.
    """

    seen_lines: set = set()
    duplicates: int = 0

    with open(f_in, "r") as in_f, open(f_out, "w") as out_f:
        for line in in_f:
            # Compare on content only: "foo" and "foo\n" are the same line.
            content = line.rstrip("\n")
            if content in seen_lines:
                duplicates += 1
                continue
            seen_lines.add(content)
            # Always terminate the line so the output never ends with a
            # dangling, unterminated entry.
            out_f.write(content + "\n")

    return duplicates


def count_file_lines(f_path: str) -> int:
    """
    Return how many lines a text file contains.

    Args:
        f_path (str): Path of the file to read, extension included.

    Returns:
        (int): Total number of lines in the file.
    """

    total = 0
    with open(f_path, "r") as handle:
        for _ in handle:
            total += 1
    return total


def get_cities() -> List:
    """
    Read the SNCF stations database and collect the city of every row.

    The file is looked up relative to the current working directory and is
    parsed as a ";"-delimited CSV with a header row.

    Returns:
        (List): The "COMMUNE" value of each row, in file order.
    """

    db_path = os.path.join(".", "data", "sncf", "sncf_stations_database.csv")
    with open(db_path, "r") as handle:
        rows = csv.DictReader(handle, delimiter=";")
        return [row["COMMUNE"] for row in rows]


def generate_data(
    cities: List,
    file_out: str,
    nb_samples: int,
    template_file: str = "data_unique_tmp.txt",
):
    """
    Generate dataset from template file.

    Each generated line is a randomly chosen template in which "{depart}"
    and "{arrivee}" are replaced by two different cities. A given ordered
    (arrival, departure) pair is used at most once.

    Args:
        cities (List): Cities from which combinations are generated. Repeated
            entries are allowed and make that city proportionally more likely
            to be drawn; empty entries are ignored.
        file_out (str): Output file, must contain extension.
        nb_samples (int): Exact number of lines to write. Zero or a negative
            value produces an empty file.
        template_file (str): File holding one sentence template per line.
            Blank lines are ignored.

    Raises:
        ValueError: If lines are requested but there is no usable template,
            or if there are not enough distinct cities to build
            ``nb_samples`` unique pairs.
    """

    print(len(cities))

    # .title() is used to capitalize the first letter of each word in the
    # string, the purpose of it is to have a unified format. It is applied
    # up-front so that pair uniqueness is checked on what is really written.
    names = [city.title() for city in cities if city and city.strip()]

    with open(template_file, "r") as f_template:
        templates = [line for line in f_template if line.strip()]

    if nb_samples > 0:
        if not templates:
            raise ValueError("No sentence template found in " + template_file)
        distinct = len(set(names))
        max_pairs = distinct * (distinct - 1)
        # Without this guard the sampling loop below would never terminate.
        if nb_samples > max_pairs:
            raise ValueError(
                "Cannot build " + str(nb_samples) + " unique city pairs from "
                + str(distinct) + " distinct cities (maximum is "
                + str(max_pairs) + ")"
            )

    used_pairs = set()
    line_count = 0

    with open(file_out, "w") as f_sortie:
        while line_count < nb_samples:
            arrival_city = random.choice(names)
            departure_city = random.choice(names)

            # At least two distinct names exist here, so this terminates.
            while arrival_city == departure_city:
                arrival_city = random.choice(names)

            combination = (arrival_city, departure_city)
            if combination in used_pairs:
                continue
            used_pairs.add(combination)

            new_line = random.choice(templates)
            new_line = new_line.replace("{depart}", departure_city)
            new_line = new_line.replace("{arrivee}", arrival_city)
            # A template taken from an unterminated last line would otherwise
            # be glued to the next generated sample.
            if not new_line.endswith("\n"):
                new_line += "\n"

            # I/O errors propagate to the caller instead of being swallowed.
            f_sortie.write(new_line)
            line_count += 1


def main():
    """
    Command-line entry point.

    Expects three arguments: the input template file, the output dataset
    file and the number of samples to generate. The input is deduplicated
    into a temporary file, the dataset is generated from it, and a short
    summary is printed. The temporary file is removed even if a step fails.

    Exits with status 1 when the arguments are missing or when the number
    of samples is not an integer.
    """
    usage = "Usage: python script.py <file_in> <file_out> <nb_sample>"

    if len(sys.argv) != 4:
        print(usage)
        sys.exit(1)

    file_in = sys.argv[1]
    file_out = sys.argv[2]
    try:
        nb_samples = int(sys.argv[3])
    except ValueError:
        # Report a bad argument the same way as a missing one, not with a
        # traceback.
        print(usage)
        print("<nb_sample> must be an integer, got: " + sys.argv[3])
        sys.exit(1)

    tmp_file = "data_unique_tmp.txt"

    try:
        duplicates: int = make_unique_lines(file_in, tmp_file)

        cities: List = get_cities()

        generate_data(cities, file_out, nb_samples)

        initial_line_number: int = count_file_lines(file_in)
        unique_sentences_number: int = count_file_lines(tmp_file)
        final_data_number: int = count_file_lines(file_out)
    finally:
        # The temporary file may not exist if the very first step failed.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)

    print("Treatment is finished : ")
    print("     - Input file : " + file_in)
    print("     - Initial number of lines : " + str(initial_line_number))
    print("     - Number of duplicates found : " + str(duplicates))
    print("     - Unique sentence forms : " + str(unique_sentences_number))
    print("     - Output file : " + file_out)
    print("     - Final dataset size : " + str(final_data_number))


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()