LSTM in PyTorch
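[PyTorch] Input and output dimensions in RNN, LSTM, and GRU - Jianshu
What exactly are the inputs and outputs of an LSTM network? - Zhihu
PyTorch documentation
Each column in the figure above can be thought of as having thickness: each cell holds a vector, not a single value.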
Setting the network parameters
torch.nn.LSTM( input_size, hidden_size, num_layers )
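input_size: the dimension of the input features; hidden_size: the size of the hidden state (what TensorFlow calls ‘num_units’)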
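To make the role of these three arguments concrete, here is a minimal sketch that builds an LSTM and inspects its weight shapes. The sizes (10, 20, 2) are illustrative values, not requirements. Note that no parameter shape involves the sequence length.

```python
import torch
from torch import nn

# Illustrative sizes only; any positive integers work.
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

# Each layer stacks the weights of its four gates along the first axis,
# so the leading dimension is 4 * hidden_size.
for name, param in lstm.named_parameters():
    print(name, tuple(param.shape))

w_ih_l0 = tuple(lstm.weight_ih_l0.shape)  # layer 0 reads the input features
w_hh_l0 = tuple(lstm.weight_hh_l0.shape)  # recurrent weights of layer 0
w_ih_l1 = tuple(lstm.weight_ih_l1.shape)  # layer 1 reads layer 0's hidden state
```

Layer 0 maps from input_size, while every later layer maps from hidden_size, which is why only the first layer's input weights depend on input_size.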
Inputs
Inputs: input, (h_0, c_0)
input, h_0, and c_0 are each 3-D tensors.
- input of shape (seq_len, batch, input_size): batch is the number of sequences in one batch; the LSTM processes all sequences of a batch at once
- h_0 of shape (num_layers * num_directions, batch, hidden_size)
- c_0 of shape (num_layers * num_directions, batch, hidden_size)
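A short sketch of how the initial states relate to these shapes. It assumes a unidirectional LSTM (num_directions = 1) and illustrative sizes. When (h_0, c_0) is omitted, PyTorch uses zeros, so passing explicit zero tensors should give the same output.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Illustrative sizes only.
seq_len, batch, input_size, hidden_size, num_layers = 5, 3, 10, 20, 2
num_directions = 1  # unidirectional

lstm = nn.LSTM(input_size, hidden_size, num_layers)
x = torch.randn(seq_len, batch, input_size)

# Explicit zero states with the documented shape.
h0 = torch.zeros(num_layers * num_directions, batch, hidden_size)
c0 = torch.zeros(num_layers * num_directions, batch, hidden_size)

out_explicit, _ = lstm(x, (h0, c0))
out_default, _ = lstm(x)  # initial states omitted, so they default to zeros

same = torch.allclose(out_explicit, out_default)
print(same)
```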
Outputs
Outputs: output, (h_n, c_n)
- output of shape (seq_len, batch, num_directions * hidden_size)
- h_n of shape (num_layers * num_directions, batch, hidden_size)
- c_n of shape (num_layers * num_directions, batch, hidden_size)
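The difference between output and h_n can be checked directly. In this sketch (illustrative sizes, default zero initial states), output collects the top layer's hidden state at every time step, while h_n collects every layer's hidden state at the final time step, so for a unidirectional LSTM the last slice of each should agree. The bidirectional case is included only to show where num_directions = 2 appears in the shapes.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
output, (hn, cn) = lstm(x)

# Top layer at the last time step, seen from both return values.
last_step_matches = torch.allclose(output[-1], hn[-1])
print(last_step_matches)

# With bidirectional=True, num_directions becomes 2.
bi = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
bi_output, (bi_hn, bi_cn) = bi(x)
print(bi_output.shape)  # (seq_len, batch, 2 * hidden_size)
print(bi_hn.shape)      # (num_layers * 2, batch, hidden_size)
```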
Example
>>> import torch
>>> from torch import nn
>>> rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)  # independent of the sequence length; any length works
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> print(output.size())
torch.Size([5, 3, 20])
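The output covers the 3 sequences in the batch (batch size 3); each sequence's output has shape [5, 20], i.e. 5 time steps with a 20-dimensional hidden state per step.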
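As a rough check of the two points in this example, the sketch below feeds the same LSTM two inputs of different lengths (5 and 12 are arbitrary choices) and then slices one sequence out of the batch along dimension 1.

```python
import torch
from torch import nn

rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

# Same module, two different sequence lengths (arbitrary values).
short = torch.randn(5, 3, 10)
long = torch.randn(12, 3, 10)

out_short, _ = rnn(short)
out_long, _ = rnn(long)

# Dimension 1 indexes the sequences in the batch.
first_seq = out_short[:, 0, :]

print(out_short.shape, out_long.shape, first_seq.shape)
```

Only the first dimension of the output changes with the input length; the batch and hidden dimensions stay at 3 and 20.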