Showing
4 changed files
with
113 additions
and
18 deletions
... | @@ -10,6 +10,7 @@ public class PostprocessingExpert { | ... | @@ -10,6 +10,7 @@ public class PostprocessingExpert { |
10 | 10 | ||
11 | List<PostprocessingSample> sampleList; | 11 | List<PostprocessingSample> sampleList; |
12 | List<PostprocessingSample> recommendationList; | 12 | List<PostprocessingSample> recommendationList; |
13 | + List<PostprocessingSample> annotatedList; | ||
13 | Logger logger = LoggerFactory.getLogger(getClass()); | 14 | Logger logger = LoggerFactory.getLogger(getClass()); |
14 | 15 | ||
15 | Set<Long> sampleItemIds; | 16 | Set<Long> sampleItemIds; |
... | @@ -19,19 +20,30 @@ public class PostprocessingExpert { | ... | @@ -19,19 +20,30 @@ public class PostprocessingExpert { |
19 | int recommendableItemCount; | 20 | int recommendableItemCount; |
20 | int recommendedItemCount; | 21 | int recommendedItemCount; |
21 | int recommendableItemUserCount; | 22 | int recommendableItemUserCount; |
23 | + int annotatedItemUserCount; | ||
22 | int recommendedItemUserCount; | 24 | int recommendedItemUserCount; |
25 | + int validRecommendationCount; | ||
23 | 26 | ||
24 | - public PostprocessingExpert(List<PostprocessingSample> sampleList, List<PostprocessingSample> recommendationList) { | 27 | + public PostprocessingExpert(List<PostprocessingSample> sampleList, List<PostprocessingSample> recommendationList, |
28 | + List<PostprocessingSample> annotatedList) { | ||
25 | this.sampleList = sampleList; | 29 | this.sampleList = sampleList; |
26 | this.recommendationList = recommendationList; | 30 | this.recommendationList = recommendationList; |
31 | + this.annotatedList = annotatedList; | ||
27 | } | 32 | } |
28 | 33 | ||
29 | - public PostprocessingCoverage getCoverage() { | 34 | + public void analyze() { |
30 | analyzeSample(); | 35 | analyzeSample(); |
31 | analyzeRecommendations(); | 36 | analyzeRecommendations(); |
37 | + } | ||
38 | + | ||
39 | + public PostprocessingCoverage getCoverage() { | ||
32 | return computeCoverage(); | 40 | return computeCoverage(); |
33 | } | 41 | } |
34 | 42 | ||
43 | + public PostprocessingPrecisionRecall getPrecisionRecall() { | ||
44 | + return computePrecisionRecall(); | ||
45 | + } | ||
46 | + | ||
35 | protected void analyzeSample() { | 47 | protected void analyzeSample() { |
36 | 48 | ||
37 | sampleItemIds = new HashSet<>(); | 49 | sampleItemIds = new HashSet<>(); |
... | @@ -55,8 +67,8 @@ public class PostprocessingExpert { | ... | @@ -55,8 +67,8 @@ public class PostprocessingExpert { |
55 | } | 67 | } |
56 | 68 | ||
57 | recommendableItemCount = sampleItemIds.size(); | 69 | recommendableItemCount = sampleItemIds.size(); |
58 | - logger.trace("Nombre d'objets recommandables {}", recommendableItemCount); | 70 | + logger.trace("C: Nombre d'objets recommandables {}", recommendableItemCount); |
59 | - logger.trace("Taille de la matrice item-user {}", sampleItemIds.size() * sampleUserIds.size()); | 71 | + logger.trace("C: Taille de la matrice item-user {}", sampleItemIds.size() * sampleUserIds.size()); |
60 | 72 | ||
61 | int sampleCoupleCount = 0; | 73 | int sampleCoupleCount = 0; |
62 | for (Long itemId : sampleItemIds) { | 74 | for (Long itemId : sampleItemIds) { |
... | @@ -64,13 +76,18 @@ public class PostprocessingExpert { | ... | @@ -64,13 +76,18 @@ public class PostprocessingExpert { |
64 | } | 76 | } |
65 | 77 | ||
66 | recommendableItemUserCount = sampleItemIds.size() * sampleUserIds.size() - sampleCoupleCount; | 78 | recommendableItemUserCount = sampleItemIds.size() * sampleUserIds.size() - sampleCoupleCount; |
67 | - logger.trace("Nombre de couples item-user dans l'échantillon {}", sampleCoupleCount); | 79 | + logger.trace("C: Nombre de couples item-user dans l'échantillon {}", sampleCoupleCount); |
68 | - logger.trace("Nombre de couples item-user recommandables {}", recommendableItemUserCount); | 80 | + logger.trace("C: Nombre de couples item-user recommandables {}", recommendableItemUserCount); |
69 | } | 81 | } |
70 | 82 | ||
71 | protected void analyzeRecommendations() { | 83 | protected void analyzeRecommendations() { |
72 | recommendedItemUserCount = 0; | 84 | recommendedItemUserCount = 0; |
85 | + validRecommendationCount = 0; | ||
73 | recommendedItemIds = new HashSet<>(); | 86 | recommendedItemIds = new HashSet<>(); |
87 | + for (PostprocessingSample annote : annotatedList) { | ||
88 | + logger.trace("Annotated item {}, user {}", annote.getItemId(), annote.getUserId()); | ||
89 | + } | ||
90 | + | ||
74 | for (PostprocessingSample reco : recommendationList) { | 91 | for (PostprocessingSample reco : recommendationList) { |
75 | Long itemId = reco.getItemId(); | 92 | Long itemId = reco.getItemId(); |
76 | Long userId = reco.getUserId(); | 93 | Long userId = reco.getUserId(); |
... | @@ -81,10 +98,17 @@ public class PostprocessingExpert { | ... | @@ -81,10 +98,17 @@ public class PostprocessingExpert { |
81 | recommendedItemUserCount++; | 98 | recommendedItemUserCount++; |
82 | } | 99 | } |
83 | } | 100 | } |
101 | + logger.trace("Recommendation item {}, user {}", reco.getItemId(), reco.getUserId()); | ||
102 | + if (annotatedList.contains(reco)) { | ||
103 | + validRecommendationCount++; | ||
104 | + } | ||
84 | } | 105 | } |
85 | recommendedItemCount = recommendedItemIds.size(); | 106 | recommendedItemCount = recommendedItemIds.size(); |
86 | - logger.trace("Nombre d'objets recommandés {}", recommendedItemCount); | 107 | + logger.trace("C: Nombre d'objets recommandés {}", recommendedItemCount); |
87 | - logger.trace("Nombre de couples item-user recommandés {}", recommendedItemUserCount); | 108 | + logger.trace("C/PR: Nombre de couples item-user recommandés {}", recommendedItemUserCount); |
109 | + annotatedItemUserCount = annotatedList.size(); | ||
110 | + logger.trace("PR: Nombre d'associations annotées {}", annotatedItemUserCount); | ||
111 | + logger.trace("PR: Nombre de recommandations annotées {}", validRecommendationCount); | ||
88 | } | 112 | } |
89 | 113 | ||
90 | protected PostprocessingCoverage computeCoverage() { | 114 | protected PostprocessingCoverage computeCoverage() { |
... | @@ -92,12 +116,12 @@ public class PostprocessingExpert { | ... | @@ -92,12 +116,12 @@ public class PostprocessingExpert { |
92 | float c2; | 116 | float c2; |
93 | int c3; | 117 | int c3; |
94 | 118 | ||
95 | - logger.trace("Nombre d'objets recommandés {}", recommendedItemCount); | 119 | + logger.trace("C: Nombre d'objets recommandés {}", recommendedItemCount); |
96 | - logger.trace("Nombre d'objets recommandables {}", recommendableItemCount); | 120 | + logger.trace("C: Nombre d'objets recommandables {}", recommendableItemCount); |
97 | c1 = (float) recommendedItemCount / recommendableItemCount; | 121 | c1 = (float) recommendedItemCount / recommendableItemCount; |
98 | logger.trace("c1 {}", String.format(Locale.FRENCH, "%.3f", c1)); | 122 | logger.trace("c1 {}", String.format(Locale.FRENCH, "%.3f", c1)); |
99 | - logger.trace("Nombre de couples item-user recommandés {}", recommendedItemUserCount); | 123 | + logger.trace("C: Nombre de couples item-user recommandés {}", recommendedItemUserCount); |
100 | - logger.trace("Nombre de couples item-user recommandables {}", recommendableItemUserCount); | 124 | + logger.trace("C: Nombre de couples item-user recommandables {}", recommendableItemUserCount); |
101 | c2 = (float) recommendedItemUserCount / recommendableItemUserCount; | 125 | c2 = (float) recommendedItemUserCount / recommendableItemUserCount; |
102 | logger.trace("c2 {}", String.format(Locale.FRENCH, "%.3f", c2)); | 126 | logger.trace("c2 {}", String.format(Locale.FRENCH, "%.3f", c2)); |
103 | c3 = recommendedItemCount; | 127 | c3 = recommendedItemCount; |
... | @@ -106,8 +130,19 @@ public class PostprocessingExpert { | ... | @@ -106,8 +130,19 @@ public class PostprocessingExpert { |
106 | return new PostprocessingCoverage(c1,c2, c3); | 130 | return new PostprocessingCoverage(c1,c2, c3); |
107 | } | 131 | } |
108 | 132 | ||
133 | + protected PostprocessingPrecisionRecall computePrecisionRecall() { | ||
134 | + float precision; | ||
135 | + float recall; | ||
109 | 136 | ||
137 | + logger.trace("PR: nombre de recommandations annotées {}", validRecommendationCount); | ||
138 | + logger.trace("PR: nombre de recommandations {}", recommendedItemUserCount); | ||
139 | + precision = (float) validRecommendationCount / recommendedItemUserCount; | ||
140 | + logger.trace("PR: précision {}", String.format(Locale.FRENCH, "%.3f", precision)); | ||
141 | + logger.trace("PR: nombre d'associations annotées {}", annotatedItemUserCount); | ||
142 | + recall = (float) validRecommendationCount / annotatedItemUserCount; | ||
143 | + logger.trace("PR: rappel {}", String.format(Locale.FRENCH, "%.3f", recall)); | ||
110 | 144 | ||
111 | - | 145 | + return new PostprocessingPrecisionRecall(precision, recall); |
146 | + } | ||
112 | 147 | ||
113 | } | 148 | } | ... | ... |
1 | +package org.legrog.recommendation.postprocess; | ||
2 | + | ||
3 | + | ||
4 | +public class PostprocessingPrecisionRecall { | ||
5 | + private float precision; | ||
6 | + private float recall; | ||
7 | + | ||
8 | + public PostprocessingPrecisionRecall(float precision, float recall) { | ||
9 | + this.precision = precision; | ||
10 | + this.recall = recall; | ||
11 | + } | ||
12 | + | ||
13 | + public float getPrecision() { | ||
14 | + return precision; | ||
15 | + } | ||
16 | + | ||
17 | + public float getRecall() { | ||
18 | + return recall; | ||
19 | + } | ||
20 | +} |
... | @@ -11,10 +11,7 @@ import org.springframework.boot.ApplicationRunner; | ... | @@ -11,10 +11,7 @@ import org.springframework.boot.ApplicationRunner; |
11 | import org.springframework.stereotype.Component; | 11 | import org.springframework.stereotype.Component; |
12 | 12 | ||
13 | import java.io.*; | 13 | import java.io.*; |
14 | -import java.util.List; | 14 | +import java.util.*; |
15 | -import java.util.Locale; | ||
16 | -import java.util.Properties; | ||
17 | -import java.util.Set; | ||
18 | import java.util.stream.Collectors; | 15 | import java.util.stream.Collectors; |
19 | import java.util.stream.StreamSupport; | 16 | import java.util.stream.StreamSupport; |
20 | 17 | ||
... | @@ -33,14 +30,24 @@ public class PostprocessingRunner implements ApplicationRunner { | ... | @@ -33,14 +30,24 @@ public class PostprocessingRunner implements ApplicationRunner { |
33 | @Value("${ratingSample.filename}") | 30 | @Value("${ratingSample.filename}") |
34 | private String ratingSampleFilename; | 31 | private String ratingSampleFilename; |
35 | 32 | ||
33 | + @Value("${collectionAnnotated.filename}") | ||
34 | + private String collectionAnnotatedFilename; | ||
35 | + | ||
36 | + @Value("${ratingAnnotated.filename}") | ||
37 | + private String ratingAnnotatedFilename; | ||
38 | + | ||
36 | @Value("${recommandations.filename}") | 39 | @Value("${recommandations.filename}") |
37 | private String recommandationsFilename; | 40 | private String recommandationsFilename; |
38 | 41 | ||
39 | @Value("${coverage.filename}") | 42 | @Value("${coverage.filename}") |
40 | private String coverageFilename; | 43 | private String coverageFilename; |
41 | 44 | ||
45 | + @Value("${precisionRecall.filename}") | ||
46 | + private String precisionRecallFilename; | ||
47 | + | ||
42 | private Logger logger = LoggerFactory.getLogger(getClass()); | 48 | private Logger logger = LoggerFactory.getLogger(getClass()); |
43 | private String sampleFilename; | 49 | private String sampleFilename; |
50 | + private String annotatedFilename; | ||
44 | 51 | ||
45 | 52 | ||
46 | @Override | 53 | @Override |
... | @@ -49,11 +56,15 @@ public class PostprocessingRunner implements ApplicationRunner { | ... | @@ -49,11 +56,15 @@ public class PostprocessingRunner implements ApplicationRunner { |
49 | loadSampleFilename(); | 56 | loadSampleFilename(); |
50 | List<PostprocessingSample> samples = loadCsvSample(new File(dataDir, sampleFilename)); | 57 | List<PostprocessingSample> samples = loadCsvSample(new File(dataDir, sampleFilename)); |
51 | List<PostprocessingSample> recommendations = loadCsvSample(new File(dataDir, recommandationsFilename)); | 58 | List<PostprocessingSample> recommendations = loadCsvSample(new File(dataDir, recommandationsFilename)); |
59 | + List<PostprocessingSample> annotated = loadCsvSample(new File(dataDir, annotatedFilename)); | ||
52 | 60 | ||
53 | - PostprocessingExpert expert = new PostprocessingExpert(samples, recommendations); | 61 | + PostprocessingExpert expert = new PostprocessingExpert(samples, recommendations, annotated); |
62 | + expert.analyze(); | ||
54 | PostprocessingCoverage coverage = expert.getCoverage(); | 63 | PostprocessingCoverage coverage = expert.getCoverage(); |
64 | + PostprocessingPrecisionRecall precisionRecall = expert.getPrecisionRecall(); | ||
55 | 65 | ||
56 | writeCsvCoverage(coverage, dataDir, coverageFilename); | 66 | writeCsvCoverage(coverage, dataDir, coverageFilename); |
67 | + writeCsvPrecisionRecall(precisionRecall, dataDir, precisionRecallFilename); | ||
57 | } | 68 | } |
58 | 69 | ||
59 | private void writeCsvCoverage(PostprocessingCoverage coverage, String dataDir, String coverageFilename) throws PostprocessingException { | 70 | private void writeCsvCoverage(PostprocessingCoverage coverage, String dataDir, String coverageFilename) throws PostprocessingException { |
... | @@ -69,6 +80,19 @@ public class PostprocessingRunner implements ApplicationRunner { | ... | @@ -69,6 +80,19 @@ public class PostprocessingRunner implements ApplicationRunner { |
69 | 80 | ||
70 | } | 81 | } |
71 | 82 | ||
83 | + private void writeCsvPrecisionRecall(PostprocessingPrecisionRecall precisionRecall, String dataDir, String precisionRecallFilename) throws PostprocessingException { | ||
84 | + try { | ||
85 | + CSVPrinter csvPrinter = new CSVPrinter(new FileWriter(new File(dataDir, precisionRecallFilename)), | ||
86 | + CSVFormat.TDF.withHeader("Precision", "Recall")); | ||
87 | + csvPrinter.printRecord(String.format(Locale.FRENCH, "%.3f", precisionRecall.getPrecision()), | ||
88 | + String.format(Locale.FRENCH, "%.3f", precisionRecall.getRecall())); | ||
89 | + csvPrinter.close(); | ||
90 | + } catch (IOException e) { | ||
91 | + throw new PostprocessingException("Can't write coverage file " + dataDir + precisionRecallFilename, e); | ||
92 | + } | ||
93 | + | ||
94 | + } | ||
95 | + | ||
72 | /** | 96 | /** |
73 | * read csv (TDF) file and map it to a list of PostprocessingSample | 97 | * read csv (TDF) file and map it to a list of PostprocessingSample |
74 | * | 98 | * |
... | @@ -77,6 +101,10 @@ public class PostprocessingRunner implements ApplicationRunner { | ... | @@ -77,6 +101,10 @@ public class PostprocessingRunner implements ApplicationRunner { |
77 | * @throws PostprocessingException | 101 | * @throws PostprocessingException |
78 | */ | 102 | */ |
79 | private List<PostprocessingSample> loadCsvSample(File file) throws PostprocessingException { | 103 | private List<PostprocessingSample> loadCsvSample(File file) throws PostprocessingException { |
104 | + if (!file.exists() || file.isDirectory()) { | ||
105 | + return new LinkedList<>(); | ||
106 | + } | ||
107 | + | ||
80 | try (Reader in = new InputStreamReader(new FileInputStream(file))) { | 108 | try (Reader in = new InputStreamReader(new FileInputStream(file))) { |
81 | Iterable<CSVRecord> records = CSVFormat.TDF.withFirstRecordAsHeader().parse(in); | 109 | Iterable<CSVRecord> records = CSVFormat.TDF.withFirstRecordAsHeader().parse(in); |
82 | 110 | ||
... | @@ -109,12 +137,15 @@ public class PostprocessingRunner implements ApplicationRunner { | ... | @@ -109,12 +137,15 @@ public class PostprocessingRunner implements ApplicationRunner { |
109 | logger.trace("ratings {}", properties.getProperty("ratings")); | 137 | logger.trace("ratings {}", properties.getProperty("ratings")); |
110 | if (Boolean.parseBoolean(properties.getProperty("ratings"))) { | 138 | if (Boolean.parseBoolean(properties.getProperty("ratings"))) { |
111 | sampleFilename = ratingSampleFilename; | 139 | sampleFilename = ratingSampleFilename; |
140 | + annotatedFilename = ratingAnnotatedFilename; | ||
112 | } else { | 141 | } else { |
113 | sampleFilename = collectionSampleFilename; | 142 | sampleFilename = collectionSampleFilename; |
143 | + annotatedFilename = collectionAnnotatedFilename; | ||
114 | } | 144 | } |
115 | } else { | 145 | } else { |
116 | // by default, takes collection | 146 | // by default, takes collection |
117 | sampleFilename = collectionSampleFilename; | 147 | sampleFilename = collectionSampleFilename; |
148 | + annotatedFilename = collectionAnnotatedFilename; | ||
118 | } | 149 | } |
119 | } catch (IOException e) { | 150 | } catch (IOException e) { |
120 | throw new PostprocessingException("Can't read properties file " + parametersFilename, e); | 151 | throw new PostprocessingException("Can't read properties file " + parametersFilename, e); | ... | ... |
... | @@ -9,6 +9,15 @@ public class PostprocessingSample { | ... | @@ -9,6 +9,15 @@ public class PostprocessingSample { |
9 | this.itemId = itemId; | 9 | this.itemId = itemId; |
10 | } | 10 | } |
11 | 11 | ||
12 | + public boolean equals(Object obj) { | ||
13 | + if (obj instanceof PostprocessingSample) { | ||
14 | + PostprocessingSample postprocessingSample = (PostprocessingSample) obj; | ||
15 | + return this.itemId == postprocessingSample.getItemId() && this.userId == postprocessingSample.getUserId(); | ||
16 | + } else { | ||
17 | + return false; | ||
18 | + } | ||
19 | + } | ||
20 | + | ||
12 | public Long getUserId() { | 21 | public Long getUserId() { |
13 | return userId; | 22 | return userId; |
14 | } | 23 | } | ... | ... |
-
Please register or login to post a comment