
Implementing BiLSTM+CRF (batched)

A batched conditional random field

In the previous post we followed the official PyTorch tutorial and reproduced a simple version of the conditional random field, and got a basic feel for the code. But as we saw, that CRF loops over tags and sentences one step at a time and has no batching support, so in practice it would be quite slow. Here we implement a batched version of the CRF. To be honest, the batched version is a bit more involved than the naive one, mainly because it has to express everything as parallel matrix operations, which takes some head-scratching. Fortunately, after studying a batched LSTM+CRF implementation as a reference and carefully dissecting its details, I basically worked out how it fits together. If you are interested, you can run the code below and study it. The core parts are annotated, but the comments may still not be entirely clear, so I recommend sketching the tensor shapes on paper to follow the details.
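Before diving into the full model, here is a minimal, self-contained sketch of the packing trick the batched CRF below relies on (the same thing `_iter_legal_batch` does): pack_sequence aligns the tokens that sit at the same position of different sentences, so the forward recursion can advance every sentence by one time step with a single tensor operation. The toy lengths here are purely illustrative.

import torch
import torch.nn.utils.rnn as rnn_utils

# Two "sentences" of lengths 4 and 2, represented by their flat token indices 0..5.
lens = torch.tensor([4, 2])
index = torch.arange(0, lens.sum(), dtype=torch.long)   # tensor([0, 1, 2, 3, 4, 5])
packed = rnn_utils.pack_sequence(torch.split(index, lens.tolist()))
# packed.batch_sizes == tensor([2, 2, 1, 1]): both sentences are still "alive" at
# steps 0-1, only the longer one at steps 2-3.
groups = torch.split(packed.data, packed.batch_sizes.tolist())
for step, g in enumerate(groups):
    print(step, g)
# step 0 -> tensor([0, 4])   first word of each sentence
# step 1 -> tensor([1, 5])   second word of each sentence
# step 2 -> tensor([2])      only the longer sentence remains
# step 3 -> tensor([3])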


import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

class BiLSTM(nn.Module):
    def __init__(self, vocab_size, tagset, embedding_dim, hidden_dim,
                 num_layers, bidirectional, dropout, pretrained=None):
        super(BiLSTM, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.tagset_size = len(tagset)
        self.bidirectional = bidirectional
        self.num_layers = num_layers
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        if pretrained is not None:
            self.word_embeds = nn.Embedding.from_pretrained(pretrained)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim // 2 if bidirectional else hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)
        self.hidden = None

    def init_hidden(self, batch_size, device):
        init_hidden_dim = self.hidden_dim // 2 if self.bidirectional else self.hidden_dim
        init_first_dim = self.num_layers * 2 if self.bidirectional else self.num_layers
        self.hidden = (
            torch.randn(init_first_dim, batch_size, init_hidden_dim).to(device),
            torch.randn(init_first_dim, batch_size, init_hidden_dim).to(device)
        )

    def repackage_hidden(self, hidden):
        """Wrap hidden states in new Tensors to detach them from their history."""
        if isinstance(hidden, torch.Tensor):
            return hidden.detach_()
        else:
            return tuple(self.repackage_hidden(h) for h in hidden)

    def forward(self, batch_input, batch_input_lens, batch_mask):
        batch_size, padding_length = batch_input.size()
        batch_input = self.word_embeds(batch_input)  # batch * padding_length * embedding_dim
        # Sentences in the batch are assumed to be sorted by length in descending order.
        batch_input = rnn_utils.pack_padded_sequence(
            batch_input, batch_input_lens, batch_first=True)
        batch_output, self.hidden = self.lstm(batch_input, self.hidden)
        self.repackage_hidden(self.hidden)
        batch_output, _ = rnn_utils.pad_packed_sequence(batch_output, batch_first=True)
        batch_output = batch_output.contiguous().view(batch_size * padding_length, -1)
        # Keep only the positions that correspond to real (non-padding) tokens.
        batch_output = batch_output[batch_mask, ...]
        out = self.hidden2tag(batch_output)
        return out

    def neg_log_likelihood(self, batch_input, batch_input_lens, batch_mask, batch_target):
        loss = nn.CrossEntropyLoss(reduction='mean')
        feats = self(batch_input, batch_input_lens, batch_mask)
        batch_target = torch.cat(batch_target, 0)
        return loss(feats, batch_target)

    def predict(self, batch_input, batch_input_lens, batch_mask):
        feats = self(batch_input, batch_input_lens, batch_mask)
        val, pred = torch.max(feats, 1)
        return pred


class CRF(nn.Module):
    # The batched version of the CRF layer.
    def __init__(self, tagset, start_tag, end_tag, device):
        super(CRF, self).__init__()
        self.tagset_size = len(tagset)
        self.START_TAG_IDX = tagset.index(start_tag)
        self.END_TAG_IDX = tagset.index(end_tag)
        self.START_TAG_TENSOR = torch.tensor([self.START_TAG_IDX], dtype=torch.long, device=device)
        self.END_TAG_TENSOR = torch.tensor([self.END_TAG_IDX], dtype=torch.long, device=device)
        # trans: (tagset_size, tagset_size); trans[i, j] is the score of transitioning state_i -> state_j
        self.trans = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size)
        )
        # Never transition into <START> and never transition out of <STOP>.
        self.trans.data[:, self.START_TAG_IDX] = -10000
        self.trans.data[self.END_TAG_IDX, :] = -10000
        self.device = device

    # Initial forward variables (alpha): one column vector per batch element.
    def init_alpha(self, batch_size, tagset_size):
        return torch.full((batch_size, tagset_size, 1), -10000, dtype=torch.float, device=self.device)

    # Backpointer table for Viterbi decoding, initialized to zeros of shape word_size * tagset_size.
    def init_path(self, size_shape):
        return torch.full(size_shape, 0, dtype=torch.long, device=self.device)

    # Regroup the flattened batch: all words at the same time step across sentences form one group.
    def _iter_legal_batch(self, batch_input_lens, reverse=False):
        index = torch.arange(0, batch_input_lens.sum(), dtype=torch.long)  # flat indices [0, total_words)
        # Split the flat indices into one tensor per sentence, then pack them so that
        # words at the same position of different sentences are aligned.
        packed_index = rnn_utils.pack_sequence(
            torch.split(index, batch_input_lens.tolist())
        )
        # Regroup by time step: each chunk holds the flat indices of the words at that position.
        batch_iter = torch.split(packed_index.data, packed_index.batch_sizes.tolist())
        batch_iter = reversed(batch_iter) if reverse else batch_iter
        for idx in batch_iter:
            yield idx, idx.size()[0]

    def score_z(self, feats, batch_input_lens):
        # Forward algorithm (partition function), mimicking the pack/pad process.
        tagset_size = feats.shape[1]
        batch_size = len(batch_input_lens)
        alpha = self.init_alpha(batch_size, tagset_size)  # batch_size * tagset_size * 1
        alpha[:, self.START_TAG_IDX, :] = 0  # start from <START>
        for legal_idx, legal_batch_size in self._iter_legal_batch(batch_input_lens):
            feat = feats[legal_idx, ].view(legal_batch_size, 1, tagset_size)
            # batch * 1 * |tag| + batch * |tag| * 1 + |tag| * |tag| = batch * |tag| * |tag| (broadcasting);
            # column j holds the scores of moving from every state i into state j.
            legal_batch_score = feat + alpha[:legal_batch_size, ] + self.trans
            alpha_new = torch.logsumexp(legal_batch_score, 1).unsqueeze(2)  # batch_size * tagset_size * 1
            alpha[:legal_batch_size, ] = alpha_new
        alpha = alpha + self.trans[:, self.END_TAG_IDX].unsqueeze(1)
        score = torch.logsumexp(alpha, 1).sum()
        return score

    def score_sentence(self, feats, batch_target):
        # Batched score of the gold tag sequences.
        # feats: (#words, tagset_size)
        # batch_target: list<torch.LongTensor>, at least one LongTensor
        # Warning: word order must match batch_target order
        def _add_start_tag(target):
            return torch.cat([self.START_TAG_TENSOR, target])

        def _add_end_tag(target):
            return torch.cat([target, self.END_TAG_TENSOR])

        from_state = [_add_start_tag(target) for target in batch_target]
        to_state = [_add_end_tag(target) for target in batch_target]

        from_state = torch.cat(from_state)  # concatenate into a 1-D tensor
        to_state = torch.cat(to_state)      # likewise a 1-D tensor
        trans_score = self.trans[from_state, to_state]  # transition scores along the gold path

        gather_target = torch.cat(batch_target).view(-1, 1)
        emit_score = torch.gather(feats, 1, gather_target)  # emission score at each gold tag position

        return trans_score.sum() + emit_score.sum()

    def viterbi(self, feats, batch_input_lens):
        word_size, tagset_size = feats.shape
        batch_size = len(batch_input_lens)
        viterbi_path = self.init_path(feats.shape)  # use feats.shape to init path.shape
        alpha = self.init_alpha(batch_size, tagset_size)  # batch_size * tagset_size * 1
        alpha[:, self.START_TAG_IDX, :] = 0  # start from <START>
        for legal_idx, legal_batch_size in self._iter_legal_batch(batch_input_lens):
            feat = feats[legal_idx, :].view(legal_batch_size, 1, tagset_size)
            legal_batch_score = feat + alpha[:legal_batch_size, ] + self.trans  # batch * |tag| * |tag|
            alpha_new, best_tag = torch.max(legal_batch_score, 1)
            alpha[:legal_batch_size, ] = alpha_new.unsqueeze(2)
            viterbi_path[legal_idx, ] = best_tag
        alpha = alpha + self.trans[:, self.END_TAG_IDX].unsqueeze(1)
        path_score, best_tag = torch.max(alpha, 1)  # batch_size * 1
        path_score = path_score.squeeze()  # path_score: #batch

        # Back-track through the backpointers, from the last time step to the first.
        best_paths = self.init_path((word_size, 1))
        for legal_idx, legal_batch_size in self._iter_legal_batch(batch_input_lens, reverse=True):
            best_paths[legal_idx, ] = best_tag[:legal_batch_size, ]
            backword_path = viterbi_path[legal_idx, ]   # legal_size * |tag|
            this_tag = best_tag[:legal_batch_size, ]    # legal_batch_size * 1
            backword_tag = torch.gather(backword_path, 1, this_tag)
            best_tag[:legal_batch_size, ] = backword_tag
        # <START> never appears in the decoded paths.
        # best_paths: #words
        return path_score.view(-1), best_paths.view(-1)


class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, tagset, embedding_dim, hidden_dim,
                 num_layers, bidirectional, dropout, start_tag, end_tag, device, pretrained=None):
        super(BiLSTM_CRF, self).__init__()
        self.bilstm = BiLSTM(vocab_size, tagset, embedding_dim, hidden_dim,
                             num_layers, bidirectional, dropout, pretrained)
        self.CRF = CRF(tagset, start_tag, end_tag, device)

    def init_hidden(self, batch_size, device):
        # init_hidden sets self.bilstm.hidden in place; do not overwrite it with the (None) return value.
        self.bilstm.init_hidden(batch_size, device)

    def forward(self, batch_input, batch_input_lens, batch_mask):
        feats = self.bilstm(batch_input, batch_input_lens, batch_mask)
        score, path = self.CRF.viterbi(feats, batch_input_lens)
        return path

    def neg_log_likelihood(self, batch_input, batch_input_lens, batch_mask, batch_target):
        feats = self.bilstm(batch_input, batch_input_lens, batch_mask)
        gold_score = self.CRF.score_sentence(feats, batch_target)
        forward_score = self.CRF.score_z(feats, batch_input_lens)
        return forward_score - gold_score

    def predict(self, batch_input, batch_input_lens, batch_mask):
        return self(batch_input, batch_input_lens, batch_mask)


class mydataset(Dataset):
    def __init__(self, dataList):
        self.datalst = dataList
        self.word_to_ix = {}
        self.tag_to_ix = {"PAD": 0, "B": 1, "I": 2, "O": 3, START_TAG: 4, STOP_TAG: 5}
        self.data = []
        self.label = []
        self.input_lens = []
        self.mask = []
        # Build the vocabulary; index 0 is reserved for padding.
        for sentence, tags in self.datalst:
            for word in sentence:
                if word not in self.word_to_ix:
                    self.word_to_ix[word] = len(self.word_to_ix) + 1
        for sentence, tags in self.datalst:
            self.label.append([self.tag_to_ix[ids] for ids in tags])
            self.data.append([self.word_to_ix[word] for word in sentence])
            lens = len(sentence)
            self.mask.append([True] * lens)
            self.input_lens.append(lens)

    def __len__(self) -> int:
        return len(self.datalst)

    def __getitem__(self, index: int):
        return {"label": self.label[index], "data": self.data[index],
                "len": self.input_lens[index], "mask": self.mask[index]}

    @staticmethod
    def collate_fn(all_example):
        data = rnn_utils.pad_sequence(
            sequences=[torch.tensor(dic["data"], dtype=torch.long) for dic in all_example],
            batch_first=True, padding_value=0)
        label = [torch.tensor(dic["label"]) for dic in all_example]
        lens = torch.tensor([dic["len"] for dic in all_example], dtype=torch.long)
        # Pad the boolean masks (True for real tokens) and flatten them so they can
        # index the flattened LSTM output.
        mask = rnn_utils.pad_sequence(
            sequences=[torch.tensor(dic["mask"], dtype=torch.bool) for dic in all_example],
            batch_first=True, padding_value=False).reshape(-1)
        return data, lens, mask, label


START_TAG = "<START>"
STOP_TAG = "<STOP>"
EMBEDDING_DIM = 5
HIDDEN_DIM = 4

training_data = [(
    "the wall street journal reported today that apple corporation made money".split(),
    "B I I I O O O B I O O".split()
), (
    "georgia tech is a university in georgia".split(),
    "B I O O O O B".split()
)]
dataset = mydataset(training_data)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=mydataset.collate_fn)

batch_size = 2
device = 'cpu'
model = BiLSTM_CRF(vocab_size=len(dataset.word_to_ix) + 1,
                   tagset=["PAD", "B", "I", "O", START_TAG, STOP_TAG],
                   embedding_dim=EMBEDDING_DIM, hidden_dim=HIDDEN_DIM,
                   num_layers=1, bidirectional=True, dropout=0.01,
                   start_tag=START_TAG, end_tag=STOP_TAG, device=device)
optimizer = optim.Adam(model.parameters(), lr=0.1, weight_decay=1e-4)
model.init_hidden(batch_size, device)
for times in range(1):
    for batch_info in dataloader:
        batch_input, batch_input_lens, batch_mask, batch_target = batch_info
        loss_train = model.neg_log_likelihood(batch_input, batch_input_lens, batch_mask, batch_target)
        optimizer.zero_grad()
        loss_train.backward()
        optimizer.step()
        print(loss_train.item())
model.init_hidden(batch_size, device)
for batch_info in dataloader:
    batch_input, batch_input_lens, batch_mask, batch_target = batch_info
    batch_pred = model.predict(batch_input, batch_input_lens, batch_mask)
    print(batch_target)
    print(batch_pred)
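Note that predict returns one flat tensor of tag indices covering every real token in the batch, with the sentences laid out one after another in their original order. A small sketch of how you might split it back into one tag sequence per sentence, assuming the dataset, dataloader and model objects defined above:

# Map the flat prediction tensor back to one readable tag sequence per sentence.
ix_to_tag = {v: k for k, v in dataset.tag_to_ix.items()}
for batch_input, batch_input_lens, batch_mask, batch_target in dataloader:
    batch_pred = model.predict(batch_input, batch_input_lens, batch_mask)
    # The flat predictions are sentence-contiguous, so split by the true lengths.
    for pred in torch.split(batch_pred, batch_input_lens.tolist()):
        print([ix_to_tag[i] for i in pred.tolist()])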
O(∩_∩)O哈哈~