更多实时更新的个人学习笔记分享,请关注:
知乎:https://www.zhihu.com/people/yuquanle/columns
微信订阅号:AI小白入门
ID: StudyForAI

<center> </center>

Flair文档嵌入教程

  • 文档嵌入与单词嵌入不同之处在于它们为您提供了一个嵌入整个文本的内容,而文字嵌入则为您提供了嵌入单个单词的内容。

  • 所有文档嵌入类都继承自DocumentEmbeddings类,并实现embed()方法,您需要调用该方法来嵌入文本。

  • 生成的所有嵌入都是Pytorch向量,因此它们可以立即用于训练和微调。

教程地址:https://github.com/zalandoresearch/flair/blob/master/resources/docs/TUTORIAL_5_DOCUMENT_EMBEDDINGS.md


文档嵌入

文档嵌入是通过嵌入文档中的所有单词创建的。 目前,我们有两种不同的方法来从字嵌入列表中获取文档嵌入。

  • Pooling
  1. 第一种方法计算文档中所有字嵌入的池化操作。 默认操作是’mean’,它给出了句子中所有单词的平均值。 将得到的嵌入作为文档嵌入。
  2. 要创建平均文档嵌入,只需先创建任意数量的TokenEmbeddings并将它们放入列表中。 然后,使用此TokenEmbeddings列表启动DocumentMeanEmbeddings。 因此,如果要使用GloVe嵌入与CharLMEmbeddings一起创建文档嵌入,请使用以下代码:
from flair.embeddings import WordEmbeddings, CharLMEmbeddings, DocumentPoolEmbeddings

# initialize the word embeddings
glove_embedding = WordEmbeddings('glove')
charlm_embedding_forward = CharLMEmbeddings('news-forward')
charlm_embedding_backward = CharLMEmbeddings('news-backward')

# initialize the document embeddings
document_embeddings = DocumentPoolEmbeddings([glove_embedding,
                                              charlm_embedding_backward,
                                              charlm_embedding_forward])

# create an example sentence
sentence = Sentence('The grass is green .')

# embed the sentence with our document embedding
document_embeddings.embed(sentence)

# now check out the embedded sentence.
print(sentence.get_embedding())

这将打印出文档的嵌入。 由于文档嵌入源自单词嵌入,因此其维度取决于您正在使用的单词嵌入的维度。

  • LSTM
  1. 第二种方法使用LSTM创建DocumentEmbeddings。 LSTM将文档中每个标记的单词嵌入作为输入,并将其最后输出状态作为文档嵌入提供。

  2. 通过传递单词嵌入列表来启动DocumentLSTMEmbeddings。

  3. 嵌入维度取决于您使用的隐藏状态的数量以及LSTM是否是双向的。

from flair.embeddings import WordEmbeddings, DocumentLSTMEmbeddings

glove_embedding = WordEmbeddings('glove')

document_embeddings = DocumentLSTMEmbeddings([glove_embedding])

# create an example sentence
sentence = Sentence('The grass is green .')

# embed the sentence with our document embedding
document_embeddings.embed(sentence)

# now check out the embedded sentence.
print(sentence.get_embedding())