// ACMTrainingSetBuilder.java
package edu.odu.cs.cs350;
import weka.core.*;
import weka.core.stemmers.IteratedLovinsStemmer;
import weka.core.tokenizers.AlphabeticTokenizer;
import weka.core.stopwords.WordsFromFile;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.StringToWordVector;
// import weka.classifiers.Evaluation;
import weka.classifiers.functions.SMO;
import weka.classifiers.meta.FilteredClassifier;
import java.util.*;
import org.apache.pdfbox.Loader;
import org.apache.pdfbox.pdmodel.PDDocument;
import org.apache.pdfbox.text.PDFTextStripper;
import java.io.*;
import java.nio.charset.StandardCharsets;
import edu.odu.cs.cs350.categorization.trainingData.TrainingData;
/**
* Build a training set based on TrainingData PDFs and train a FilteredClassifier.
*/
public class ACMTrainingSetBuilder {
    /**
     * Main method to build the training set and train the classifier.
     *
     * <p>Pipeline: read the document listing from the TrainingData repository,
     * derive the nominal category list from the listing's subdirectory names,
     * build a raw {@code Instances} with a string {@code text} attribute and a
     * nominal {@code class} attribute, vectorize with a
     * {@code StringToWordVector}, train an SMO wrapped in a
     * {@code FilteredClassifier}, and persist the models to disk.</p>
     *
     * @param inputArguments program arguments (unused)
     * @throws Exception if files cannot be read or written
     */
    public static void main(String[] inputArguments) throws Exception {
        // Read listing and derive categories from it.
        List<String> docPaths = getDocumentListing();
        List<String> possibleCategories = getACMClasses(docPaths);

        // Build raw Instances with a string `text` attribute and nominal `class`.
        // Passing a null value list to the Attribute constructor marks it as a
        // string attribute in Weka.
        ArrayList<Attribute> inputAttributes = new ArrayList<>();
        Attribute textAttribute = new Attribute("text", (List<String>) null);
        Attribute categoryAttribute = new Attribute("class", possibleCategories);
        inputAttributes.add(textAttribute);
        inputAttributes.add(categoryAttribute);
        Instances raw = new Instances("ACMDocsRaw", inputAttributes, docPaths.size());
        raw.setClassIndex(1);

        addStringInstancesFromRepository(raw, possibleCategories, docPaths);
        StringToWordVector filteredTrainingData = buildTrainingFilter(raw);
        Instances trainingSet = Filter.useFilter(raw, filteredTrainingData);

        // Train an SMO classifier wrapped in a FilteredClassifier so the same
        // StringToWordVector filter is applied at classification time.
        SMO learningMachine = new SMO();
        FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setFilter(filteredTrainingData);
        filteredClassifier.setClassifier(learningMachine);
        filteredClassifier.buildClassifier(raw);

        writeModels(raw, trainingSet, filteredClassifier);
        // For offline accuracy checks, cross-validate with
        // weka.classifiers.Evaluation#crossValidateModel on trainingSet;
        // intentionally omitted from the production build.
    }

    /**
     * Write trained models to disk.
     *
     * <p>Writes the vectorized training set as ARFF (not consumed by the current
     * classifier, which loads the .model files), the trained classifier, and the
     * raw header used to build compatible instances at classification time.</p>
     *
     * @param raw raw Instances header
     * @param trainingSet vectorized training Instances
     * @param trainedClassifier trained FilteredClassifier
     * @throws Exception if files cannot be written
     */
    public static void writeModels(Instances raw, Instances trainingSet,
                                   FilteredClassifier trainedClassifier) throws Exception {
        // Write the ARFF explicitly in UTF-8; a plain FileWriter would fall back
        // to the platform default charset and could corrupt non-ASCII tokens.
        try (BufferedWriter writer = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream("acm_training.arff"),
                                       StandardCharsets.UTF_8))) {
            writer.write(trainingSet.toString());
        }
        SerializationHelper.write("acm_filtered.model", trainedClassifier);
        SerializationHelper.write("acm_raw_header.model", raw);
    }

    /**
     * Gets PDFs from the Maven repo, extracts text, and adds Instances.
     *
     * <p>Documents whose text cannot be extracted or whose category is not
     * recognized are skipped with a warning rather than added; passing the
     * resulting nulls to Weka's setters would throw at runtime.</p>
     *
     * @param raw empty Instances object to fill
     * @param possibleCategories list of valid ACM categories
     * @param docPaths list of document paths within the Maven repo
     * @throws IOException if a file can't be read
     */
    public static void addStringInstancesFromRepository(Instances raw,
                                                        List<String> possibleCategories,
                                                        List<String> docPaths) throws IOException {
        if (docPaths == null) {
            System.err.println("Document path list is null");
            return;
        }
        Attribute textAttribute = raw.attribute("text");
        Attribute categoryAttribute = raw.classAttribute();
        for (String resourceRelative : docPaths) {
            if (resourceRelative == null || resourceRelative.trim().isEmpty()) {
                continue;
            }
            String documentText = getTextFromDocument(resourceRelative);
            String category = determineCategoryFromPath(resourceRelative, possibleCategories);
            // Both helpers warn and return null on failure; actually skip here so we
            // never hand null to addStringValue()/setValue() below.
            if (documentText == null || category == null) {
                continue;
            }
            DenseInstance instance = new DenseInstance(2);
            instance.setDataset(raw);
            // Register the string value in the attribute's pool, then set its index.
            instance.setValue(textAttribute, textAttribute.addStringValue(documentText));
            instance.setValue(categoryAttribute, category);
            raw.add(instance);
        }
    }

    /**
     * Determine the ACM category from a document path within the repo
     * (the path component before the first '/').
     *
     * @param resourceRelative path to the document relative to the TrainingData resource path
     * @param possibleCategories list of valid ACM categories
     * @return determined category, or null if not recognized
     */
    public static String determineCategoryFromPath(String resourceRelative,
                                                   List<String> possibleCategories) {
        String category = null;
        int slash = resourceRelative.indexOf('/');
        if (slash > 0) {
            category = resourceRelative.substring(0, slash);
        }
        if (category == null || !possibleCategories.contains(category)) {
            System.err.println("Warning: category '" + category + "' not recognized for resource "
                    + resourceRelative + "; skipping.");
            return null;
        }
        return category;
    }

    /**
     * Extract text from a PDF document in the TrainingData repository.
     *
     * @param resourceRelative path to the document relative to the TrainingData resource path
     * @return extracted text, or null if the resource is missing or cannot be parsed
     * @throws IOException if the resource stream cannot be read
     */
    public static String getTextFromDocument(String resourceRelative) throws IOException {
        String documentText = null;
        try (InputStream inputDocument = TrainingData.class
                .getResourceAsStream(TrainingData.resourcePath + resourceRelative)) {
            if (inputDocument == null) {
                System.err.println("Warning: resource not found: "
                        + TrainingData.resourcePath + resourceRelative);
                return null;
            }
            // PDFBox 3's Loader works from a byte array, so buffer the whole resource.
            byte[] inputBuffer = inputDocument.readAllBytes();
            try (PDDocument doc = Loader.loadPDF(inputBuffer)) {
                PDFTextStripper stripper = new PDFTextStripper();
                documentText = stripper.getText(doc);
            } catch (IOException e) {
                // Best-effort: a single malformed PDF should not abort the whole build.
                System.err.println("Warning: failed to parse PDF " + resourceRelative
                        + " : " + e.getMessage());
                return null;
            }
        }
        return documentText;
    }

    /**
     * Build a StringToWordVector filter for training and classification.
     *
     * <p>Settings were tuned empirically against cross-validation accuracy;
     * see the inline notes on each option.</p>
     *
     * @param raw raw Instances to set input format
     * @return configured StringToWordVector filter
     * @throws Exception if the filter cannot be initialized
     */
    public static StringToWordVector buildTrainingFilter(Instances raw) throws Exception {
        StringToWordVector trainingFilter = new StringToWordVector();
        trainingFilter.setWordsToKeep(5000); // below 5000 seems to drop accuracy but above makes no difference
        trainingFilter.setMinTermFreq(1); // tuning this made no measurable difference
        trainingFilter.setLowerCaseTokens(true);
        trainingFilter.setTFTransform(false); // enabling decreases validation accuracy slightly (~1%)
        trainingFilter.setIDFTransform(false); // same as above
        trainingFilter.setOutputWordCounts(true); // not yet tested against accuracy
        trainingFilter.setAttributeNamePrefix("count_"); // don't change: downstream code depends on it
        // Use AlphabeticTokenizer to avoid punctuation and other invalid characters from PDFs.
        AlphabeticTokenizer alphabeticTokenizer = new AlphabeticTokenizer();
        trainingFilter.setTokenizer(alphabeticTokenizer);
        // Lovins stemming is fairly aggressive but was the most practical option here.
        IteratedLovinsStemmer aggressiveStemmer = new IteratedLovinsStemmer();
        trainingFilter.setStemmer(aggressiveStemmer);
        // Load custom stopwords from stopwords.txt in the working directory.
        WordsFromFile ignoredWords = new WordsFromFile();
        ignoredWords.setStopwords(new File("stopwords.txt"));
        trainingFilter.setStopwordsHandler(ignoredWords);
        trainingFilter.setInputFormat(raw);
        return trainingFilter;
    }

    /**
     * Get the list of document paths from the TrainingData repository listing.
     *
     * @return list of document paths including their category prefix
     * @throws IOException if the listing cannot be found or read
     */
    public static List<String> getDocumentListing() throws IOException {
        InputStream directoryIn = TrainingData.class
                .getResourceAsStream(TrainingData.resourcePath + TrainingData.directory);
        // Fail with a diagnosable message instead of an NPE in InputStreamReader
        // when the listing resource is missing from the classpath.
        if (directoryIn == null) {
            throw new FileNotFoundException("Training data listing not found: "
                    + TrainingData.resourcePath + TrainingData.directory);
        }
        List<String> docPaths = new ArrayList<>();
        try (BufferedReader directoryReader = new BufferedReader(
                new InputStreamReader(directoryIn, StandardCharsets.UTF_8))) {
            String line;
            while ((line = directoryReader.readLine()) != null) {
                line = line.trim();
                if (!line.isEmpty()) {
                    docPaths.add(line);
                }
            }
        }
        return docPaths;
    }

    /**
     * Get the list of unique ACM categories from document paths
     * (using the subfolder names), preserving first-seen order.
     *
     * @param docPaths list of document paths from the repository
     * @return list of unique ACM categories
     */
    public static List<String> getACMClasses(List<String> docPaths) {
        // LinkedHashSet dedupes in O(1) per lookup while keeping insertion order,
        // so the resulting nominal attribute order matches the original behavior.
        Set<String> categories = new LinkedHashSet<>();
        for (String docPath : docPaths) {
            int directorySeparatorIndex = docPath.indexOf('/');
            if (directorySeparatorIndex > 0) {
                categories.add(docPath.substring(0, directorySeparatorIndex));
            }
        }
        return new ArrayList<>(categories);
    }
}