Статьи

API прогнозирования: машинное обучение от Google

Вступление

Одним из захватывающих API из более чем 50 API, предлагаемых Google, является API Prediction. Он обеспечивает сопоставление с образцом и возможности машинного обучения, такие как рекомендации или категоризация. Это понятие похоже на возможности машинного обучения, которые мы видим в других решениях (например, в Apache Mahout): мы можем обучить систему с помощью набора обучающих данных, а затем приложения, основанные на Prediction API, могут порекомендовать («предсказать»), что продукты, которые могут понравиться пользователю, или они могут классифицировать спам и т. д. В этом посте мы рассмотрим пример того, как классифицировать SMS-сообщения — будь то спам или ценные тексты («радиолюбители»).

Использование API прогнозирования

Чтобы использовать API прогнозирования, необходимо включить службу через
консоль API Google. Для загрузки данных обучения API прогнозирования также требуется Google Cloud Storage. Набор данных, используемый в этом посте, взят из репозитория машинного обучения UCI. Репозиторий UCI Machine Learning имеет 235 общедоступных наборов данных, этот пост основан на
наборе данных SMS Spam Collections.

Чтобы сначала загрузить данные обучения, нам нужно создать корзину в облачном хранилище Google. В консоли Google API нам нужно нажать на Google Cloud Storage, а затем на Google Cloud Storage Manager: это откроет веб-страницу, где мы сможем создавать новые сегменты и загружать или удалять файлы.GoogleStorage2

Файл UCI SMS Spam Collection не подходит, как и для Prediction API, его необходимо преобразовать в следующий формат (категории — ham / spam — необходимо указать в кавычках, а также текст SMS):

«Ham» «Идите до точки jurong, сумасшедший .. Доступно только в Bugis N Great World» шведский стол «… Cine там получил Amore Wat …»

GoogleStorage4

Google Prediction API предлагает несколько команд, которые могут быть вызваны через интерфейс REST. Самым простым способом тестирования Prediction API является использование Prediction API explorer.GooglePrediction1

Как только данные обучения будут доступны в облачном хранилище Google, мы можем приступить к обучению системе машинного обучения на основе API прогнозирования. Чтобы начать обучение нашей модели, нам нужно запуститьвести на прогнозирование.trainedmodels.insert. Все команды требуют аутентификации, она основана на стандарте OAuth 2.0.

GooglePrediction2

В меню вставки нам нужно указать поля, которые мы хотим включить в ответ. В теле запроса нам нужно определить идентификатор (он будет использоваться в качестве ссылки на модель в командах, которые будут использоваться позже), storageDataLocation, куда мы загружаем данные обучения (путь к облачному хранилищу Google) и modelType (может регрессия или классификация, для фильтрации спама это классификация):

GooglePrediction-SpamInsert1

Обучение длится некоторое время, мы можем проверить его состояние с помощью командыpretion.trainedmodels.get. Поле состояния будет «РАБОТАЕТ» и будет изменено на «ВЫПОЛНЕНО» после завершения обучения.

GooglePrediction-SpamGet1
GooglePrediction-SpamGet2

Теперь мы готовы запустить наш тест для системы машинного обучения и собираемся определить, является ли данный текст спамом или ветчиной. API-интерфейс Prediction для этого действия —pretion.trainedmodels.predict. В поле id мы должны обратиться к идентификатору, который мы определили для командыpretion.trainedmodels.insert (bighadoop-00001), и нам также нужно указать тело запроса — input будет csvInstance, а затем мы введем нужный текст классифицироваться (например, «Бесплатный вход»)

GooglePrediction-SpamPredict1

Затем система возвращается с категорией (спам) и оценкой (0,822158 для спама, 0,177842 для ветчины):

GooglePrediction-SpamPredict2

Библиотеки API Google Prediction

Google также предлагает специальный пример приложения, включающий весь код, необходимый для его запуска в Google App Engine. Он называется Try-Prediction, а код написан на Python, а также на Java. Приложение может быть протестировано на
http://try-prediction.appspot.com. Например, если мы введем цитату для модели определения языка от Нильса Бора: «Предсказание очень сложно, особенно если речь идет о будущем», он вернет, что это, скорее всего, текст на английском языке (54,4%) ,
TryPrediction

Ключевая часть кода Python находится в Foret.py: 

class PredictAPI(webapp.RequestHandler):
  '''This class handles Ajax prediction requests, i.e. not user initiated
     web sessions but remote procedure calls initiated from the Javascript
     client code running the browser.
  '''




  def get(self):
    try:
      # Read server-side OAuth 2.0 credentials from datastore and
      # raise an exception if credentials not found.
      credentials = StorageByKeyName(CredentialsModel, USER_AGENT, 
                                    'credentials').locked_get()
      if not credentials or credentials.invalid:
        raise Exception('missing OAuth 2.0 credentials')




      # Authorize HTTP session with server credentials and obtain  
      # access to prediction API client library.
      http = credentials.authorize(httplib2.Http())
      service = build('prediction', 'v1.4', http=http)
      papi = service.trainedmodels()




      # Read and parse JSON model description data.
      models = parse_json_file(MODELS_FILE)




      # Get reference to user's selected model.
      model_name = self.request.get('model')
      model = models[model_name]




      # Build prediction data (csvInstance) dynamically based on form input.
      vals = []
      for field in model['fields']:
        label = field['label']
        val = str(self.request.get(label))
        vals.append(val)
      body = {'input' : {'csvInstance' : vals }}
      logging.info('model:' + model_name + ' body:' + str(body))




      # Make a prediction and return JSON results to Javascript client.
      ret = papi.predict(id=model['model_id'], body=body).execute()
      self.response.out.write(json.dumps(ret))




    except Exception, err:
      # Capture any API errors here and pass response from API back to
      # Javascript client embedded in a special error indication tag.
      err_str = str(err)
      if err_str[0:len(ERR_TAG)] != ERR_TAG:
        err_str = ERR_TAG + err_str + ERR_END
      self.response.out.write(err_str)

Java-версия веб-приложения Prediction выглядит следующим образом:

public class PredictServlet extends HttpServlet {




  @Override
  protected void doGet(HttpServletRequest request,
                       HttpServletResponse response) throws ServletException, 
                                                            IOException {
    Entity credentials = null;
    try {
      // Retrieve server credentials from app engine datastore.
      DatastoreService datastore = 
        DatastoreServiceFactory.getDatastoreService();
      Key credsKey = KeyFactory.createKey("Credentials", "Credentials");
      credentials = datastore.get(credsKey);
    } catch (EntityNotFoundException ex) {
      // If can't obtain credentials, send exception back to Javascript client.
      response.setContentType("text/html");
      response.getWriter().println("exception: " + ex.getMessage());
    }




    // Extract tokens from retrieved credentials.
    AccessTokenResponse tokens = new AccessTokenResponse();
    tokens.accessToken = (String) credentials.getProperty("accessToken");
    tokens.expiresIn = (Long) credentials.getProperty("expiresIn");
    tokens.refreshToken = (String) credentials.getProperty("refreshToken");
    String clientId = (String) credentials.getProperty("clientId");
    String clientSecret = (String) credentials.getProperty("clientSecret");
    tokens.scope = IndexServlet.scope;




    // Set up the HTTP transport and JSON factory
    HttpTransport httpTransport = new NetHttpTransport();
    JsonFactory jsonFactory = new JacksonFactory();




    // Get user requested model, if specified.
    String model_name = request.getParameter("model");




    // Parse model descriptions from models.json file.
    Map models = 
      IndexServlet.parseJsonFile(IndexServlet.modelsFile);




    // Setup reference to user specified model description.
    Map selectedModel = 
      (Map) models.get(model_name);
    
    // Obtain model id (the name under which model was trained), 
    // and iterate over the model fields, building a list of Strings
    // to pass into the prediction request.
    String modelId = (String) selectedModel.get("model_id");
    List params = new ArrayList();
    List<Map > fields = 
      (List<Map >) selectedModel.get("fields");
    for (Map field : fields) {
      // This loop is populating the input csv values for the prediction call.
      String label = field.get("label");
      String value = request.getParameter(label);
      params.add(value);
    }




    // Set up OAuth 2.0 access of protected resources using the retrieved
    // refresh and access tokens, automatically refreshing the access token 
    // whenever it expires.
    GoogleAccessProtectedResource requestInitializer = 
      new GoogleAccessProtectedResource(tokens.accessToken, httpTransport, 
                                        jsonFactory, clientId, clientSecret, 
                                        tokens.refreshToken);




    // Now populate the prediction data, issue the API call and return the
    // JSON results to the Javascript AJAX client.
    Prediction prediction = new Prediction(httpTransport, requestInitializer, 
                                           jsonFactory);
    Input input = new Input();
    InputInput inputInput = new InputInput();
    inputInput.setCsvInstance(params);
    input.setInput(inputInput);
    Output output = 
      prediction.trainedmodels().predict(modelId, input).execute();
    response.getWriter().println(output.toPrettyString());
  }
}

Помимо поддержки Python и Java, Google также предлагает библиотеки .NET, Objective-C, Ruby, Go, JavaScript, PHP и т. Д. Для API прогнозирования.