Jean-Francois Leveque

Pré-traitement : tirage aléatoire d'un pourcentage indiqué d'éléments annotés.

......@@ -14,4 +14,37 @@
<artifactId>grog-recommendation-preprocess</artifactId>
<version>3.0-SNAPSHOT</version>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-csv</artifactId>
<version>1.3</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<version>1.5.2.RELEASE </version>
<executions>
<execution>
<goals>
<goal>repackage</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
\ No newline at end of file
......
package org.legrog.recommendation.preprocess;
/**
 * Association between one user and one item, each identified by a numeric id.
 *
 * <p>Both ids are fixed at construction time and are stored exactly as given
 * (no null check is performed).
 */
public class AssociationElement {

    private final Long userId;
    private final Long itemId;

    /**
     * Creates an association between a user and an item.
     *
     * @param userId id of the user
     * @param itemId id of the item
     */
    public AssociationElement(Long userId, Long itemId) {
        this.userId = userId;
        this.itemId = itemId;
    }

    /**
     * @return the user id passed to the constructor
     */
    public Long getUserId() {
        return userId;
    }

    /**
     * @return the item id passed to the constructor
     */
    public Long getItemId() {
        return itemId;
    }
}
package org.legrog.recommendation.preprocess;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
/**
 * Spring Boot entry point of the preprocessing tool.
 */
@SpringBootApplication
public class PreprocessingApplication {

    /**
     * Starts the Spring application with this class as its configuration source.
     *
     * @param args command-line arguments, forwarded unchanged to Spring Boot
     */
    public static void main(String[] args) {
        // Instance form of the static SpringApplication.run(source, args) shortcut.
        SpringApplication application = new SpringApplication(PreprocessingApplication.class);
        application.run(args);
    }
}
package org.legrog.recommendation.preprocess;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.csv.CSVRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Properties;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
/**
 * Splits a complete tab-separated association file into a "sample" file and an
 * "annotated" file.
 *
 * <p>On startup the runner:
 * <ol>
 *   <li>reads the {@code ratings} and {@code annotatePercent} keys from the parameters
 *       properties file located in {@code data.dir};</li>
 *   <li>loads every record of the complete file (ratings or collections, depending on
 *       {@code ratings});</li>
 *   <li>randomly draws {@code annotatePercent} percent of the record indexes;</li>
 *   <li>writes the drawn records to the annotated file and all other records to the
 *       sample file.</li>
 * </ol>
 *
 * <p>Input columns are looked up by header name ({@code userId}, {@code itemId} and, for
 * ratings, {@code rating}); output columns are written in the order
 * {@code itemId}, {@code userId}[, {@code rating}].
 */
@Component
public class PreprocessingRunner implements ApplicationRunner {

    private static final Logger logger = LoggerFactory.getLogger(PreprocessingRunner.class);

    /** Percentage used when the parameters file has no {@code annotatePercent} key. */
    private static final int DEFAULT_ANNOTATE_PERCENT = 1;

    @Value("${parameters.filename}")
    private String parametersFilename;
    @Value("${data.dir}")
    private String dataDir;
    @Value("${collectionComplete.filename}")
    private String collectionCompleteFilename;
    @Value("${ratingComplete.filename}")
    private String ratingCompleteFilename;
    @Value("${collectionSample.filename}")
    private String collectionSampleFilename;
    @Value("${ratingSample.filename}")
    private String ratingSampleFilename;
    @Value("${collectionAnnotated.filename}")
    private String collectionAnnotatedFilename;
    @Value("${ratingAnnotated.filename}")
    private String ratingAnnotatedFilename;

    // Resolved by setFilenames() once the ratings/collections mode is known.
    private String completeFilename;
    private String sampleFilename;
    private String annotatedFilename;

    // Loaded from the parameters file by loadParameters().
    private boolean ratings;
    private int annotatePercent;

    /**
     * Runs the whole preprocessing: load parameters, load the complete file, draw the
     * annotated indexes and write both output files.
     *
     * @param applicationArguments not used
     * @throws Exception if a parameter is invalid or a file cannot be read or written
     */
    @Override
    public void run(ApplicationArguments applicationArguments) throws Exception {
        loadParameters();
        setFilenames();

        List<AssociationElement> associationElements =
                loadAssociationElements(new File(dataDir, completeFilename));
        Set<Integer> annotateIndexes = chooseAnnotated(associationElements.size());

        writeSampleAndAnnotated(
                new File(dataDir, sampleFilename),
                new File(dataDir, annotatedFilename),
                annotateIndexes,
                associationElements);
    }

    /**
     * Draws a uniformly random set of distinct indexes in {@code [0, size)}.
     *
     * <p>The number of indexes is {@code annotatePercent} percent of {@code size}, rounded
     * up and clamped to {@code [0, size]}. An empty input or a percentage of 0 yields an
     * empty set; a percentage of 100 yields every index.
     *
     * @param size number of loaded elements
     * @return the indexes of the elements that go to the annotated file
     */
    private Set<Integer> chooseAnnotated(int size) {
        // long arithmetic: size * annotatePercent may not fit in an int for large inputs.
        long roundedUp = ((long) size * annotatePercent + 99) / 100;
        int target = (int) Math.max(0, Math.min(size, roundedUp));

        // Floyd's sampling algorithm: exactly `target` draws, no retry loop, so it always
        // terminates (even at 100 percent) and never calls nextInt with a bound of 0.
        Set<Integer> annotatedChosen = new HashSet<>();
        Random random = new Random();
        for (int upper = size - target; upper < size; upper++) {
            int candidate = random.nextInt(upper + 1);
            if (!annotatedChosen.add(candidate)) {
                // candidate was already drawn; `upper` cannot be in the set yet.
                annotatedChosen.add(upper);
            }
        }
        return annotatedChosen;
    }

    /**
     * Writes every element either to the annotated file (when its index was drawn) or to
     * the sample file (otherwise). Both files start with a header line.
     *
     * @param sampleFile          destination of the elements that were not drawn
     * @param annotatedFile       destination of the drawn elements
     * @param annotateIndexes     indexes, in {@code associationElements}, of the drawn elements
     * @param associationElements all loaded elements; must hold {@link RatingElement}s in ratings mode
     * @throws PreprocessingException if one of the files cannot be written
     */
    private void writeSampleAndAnnotated(File sampleFile, File annotatedFile, Set<Integer> annotateIndexes,
                                         List<AssociationElement> associationElements)
            throws PreprocessingException {
        CSVFormat format = ratings
                ? CSVFormat.TDF.withHeader("itemId", "userId", "rating")
                : CSVFormat.TDF.withHeader("itemId", "userId");

        // try-with-resources closes writers and printers even when a write fails.
        try (Writer sampleWriter = newUtf8Writer(sampleFile);
             Writer annotatedWriter = newUtf8Writer(annotatedFile);
             CSVPrinter samplePrinter = new CSVPrinter(sampleWriter, format);
             CSVPrinter annotatedPrinter = new CSVPrinter(annotatedWriter, format)) {
            for (int i = 0; i < associationElements.size(); i++) {
                AssociationElement element = associationElements.get(i);
                CSVPrinter printer = annotateIndexes.contains(i) ? annotatedPrinter : samplePrinter;
                if (ratings) {
                    RatingElement ratingElement = (RatingElement) element;
                    printer.printRecord(
                            ratingElement.getItemId(), ratingElement.getUserId(), ratingElement.getRating());
                } else {
                    printer.printRecord(element.getItemId(), element.getUserId());
                }
            }
        } catch (IOException e) {
            throw new PreprocessingException(
                    "Can't write sample or annotated file " + sampleFile + " / " + annotatedFile, e);
        }
    }

    /**
     * Opens a buffered UTF-8 writer that truncates or creates {@code file}.
     *
     * @param file file to write
     * @return a writer the caller must close
     * @throws IOException if the file cannot be opened for writing
     */
    private static Writer newUtf8Writer(File file) throws IOException {
        return new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file), StandardCharsets.UTF_8));
    }

    /**
     * Loads every record of a tab-separated file whose first line is the header.
     *
     * @param file complete ratings or collections file
     * @return one element per record, in file order ({@link RatingElement}s in ratings mode)
     * @throws PreprocessingException if the file cannot be read, a required column is
     *                                missing, or a value is not a valid number
     */
    private List<AssociationElement> loadAssociationElements(File file) throws PreprocessingException {
        try (Reader in = new BufferedReader(
                new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8))) {
            Iterable<CSVRecord> records = CSVFormat.TDF.withFirstRecordAsHeader().parse(in);
            // The parser is lazy: the stream must be fully consumed before the reader closes.
            return StreamSupport.stream(records.spliterator(), false)
                    .map(this::toElement)
                    .collect(Collectors.toList());
        } catch (IOException e) {
            throw new PreprocessingException("Can't read CSV file " + file, e);
        } catch (IllegalArgumentException | IllegalStateException e) {
            // Unparsable numbers and column-lookup failures are unchecked; add the file for context.
            throw new PreprocessingException("Malformed record in CSV file " + file, e);
        }
    }

    /**
     * Converts one CSV record to an element matching the current mode.
     *
     * @param record record with {@code userId}, {@code itemId} and, in ratings mode, {@code rating}
     * @return a {@link RatingElement} in ratings mode, an {@link AssociationElement} otherwise
     */
    private AssociationElement toElement(CSVRecord record) {
        Long userId = Long.valueOf(record.get("userId"));
        Long itemId = Long.valueOf(record.get("itemId"));
        if (ratings) {
            return new RatingElement(userId, itemId, Integer.valueOf(record.get("rating")));
        }
        return new AssociationElement(userId, itemId);
    }

    /**
     * Selects the rating or collection file names according to {@code ratings}.
     */
    private void setFilenames() {
        if (ratings) {
            completeFilename = ratingCompleteFilename;
            sampleFilename = ratingSampleFilename;
            annotatedFilename = ratingAnnotatedFilename;
        } else {
            completeFilename = collectionCompleteFilename;
            sampleFilename = collectionSampleFilename;
            annotatedFilename = collectionAnnotatedFilename;
        }
    }

    /**
     * Reads {@code ratings} and {@code annotatePercent} from the parameters properties file.
     *
     * <p>When {@code ratings} is absent the runner works on collections; when
     * {@code annotatePercent} is absent it defaults to {@value #DEFAULT_ANNOTATE_PERCENT}.
     *
     * @throws PreprocessingException if the file cannot be read or {@code annotatePercent}
     *                                is not an integer between 0 and 100
     */
    private void loadParameters() throws PreprocessingException {
        File parametersFile = new File(dataDir, parametersFilename);
        Properties properties = new Properties();
        try (InputStream in = new FileInputStream(parametersFile)) {
            properties.load(in);
        } catch (IOException e) {
            throw new PreprocessingException("Can't read parameters properties file " + parametersFile, e);
        }

        String ratingsValue = properties.getProperty("ratings");
        if (ratingsValue != null) {
            logger.trace("ratings {}", ratingsValue);
        }
        // By default (key absent), work on collections.
        ratings = ratingsValue != null && Boolean.parseBoolean(ratingsValue.trim());

        String percentValue = properties.getProperty("annotatePercent");
        annotatePercent = percentValue == null
                ? DEFAULT_ANNOTATE_PERCENT
                : parseAnnotatePercent(percentValue);
    }

    /**
     * Parses and validates the {@code annotatePercent} parameter.
     *
     * @param value raw property value; surrounding whitespace is ignored
     * @return the percentage, between 0 and 100 inclusive
     * @throws PreprocessingException if the value is not an integer or is out of range
     */
    private static int parseAnnotatePercent(String value) throws PreprocessingException {
        int percent;
        try {
            percent = Integer.parseInt(value.trim());
        } catch (NumberFormatException e) {
            throw new PreprocessingException("annotatePercent is not an integer: '" + value + "'", e);
        }
        if (percent < 0 || percent > 100) {
            throw new PreprocessingException("annotatePercent must be between 0 and 100 but was " + percent);
        }
        return percent;
    }

    /**
     * Failure of a preprocessing step (invalid parameter, unreadable or unwritable file).
     * Static so that it carries no hidden reference to the enclosing runner.
     */
    private static class PreprocessingException extends Exception {

        private static final long serialVersionUID = 1L;

        PreprocessingException(String message) {
            super(message);
        }

        PreprocessingException(String message, Throwable cause) {
            super(message, cause);
        }
    }
}
package org.legrog.recommendation.preprocess;
/**
 * Association between a user and an item that also carries the rating the user gave.
 *
 * <p>The rating is fixed at construction time and is stored exactly as given
 * (no null or range check is performed).
 */
public class RatingElement extends AssociationElement {

    private final Integer rating;

    /**
     * Creates a rated association between a user and an item.
     *
     * @param userId id of the user
     * @param itemId id of the item
     * @param rating rating attached to the association
     */
    public RatingElement(Long userId, Long itemId, Integer rating) {
        super(userId, itemId);
        this.rating = rating;
    }

    /**
     * @return the rating passed to the constructor
     */
    public Integer getRating() {
        return rating;
    }
}
parameters.filename=${parameters.filename}
collectionSample.filename=${collectionSample.filename}
ratingSample.filename=${ratingSample.filename}
recommandations.filename=${recommandations.filename}
coverage.filename=${coverage.filename}
data.dir=dumb/
collectionComplete.filename=${collectionComplete.filename}
ratingComplete.filename=${ratingComplete.filename}
collectionAnnotated.filename=${collectionAnnotated.filename}
ratingAnnotated.filename=${ratingAnnotated.filename}
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<configuration>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<!-- encoders are assigned the type
ch.qos.logback.classic.encoder.PatternLayoutEncoder by default -->
<encoder>
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
</encoder>
</appender>
<logger name="org.legrog" level="DEBUG"/>
<logger name="org.legrog.recommendation.preprocess" level="TRACE"/>
<root level="warn">
<appender-ref ref="STDOUT" />
</root>
</configuration>
\ No newline at end of file