Spaces · Runtime error
PZR0033 committed · Commit cf73df8 · 1 parent: 5fb2a02

test env update

Files changed:
- rl_agent/policy.py (+1, -1)
- rl_agent/test_env.py (+7, -4)
rl_agent/policy.py CHANGED

@@ -10,7 +10,7 @@ class Policy(nn.Module):
 
         self.layer1 = nn.Linear(input_channels, 2 * input_channels)
         self.tanh1 = nn.Tanh()
-        self.layer2 = nn.
+        self.layer2 = nn.Linear(2 * input_channels, 1)
         self.tanh2 = nn.Tanh()
 
     def forward(self, state):
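For orientation, a minimal sketch of how Policy reads after this change. The diff shows only the constructor and the forward signature, so the import and the forward composition (linear, tanh, linear, tanh) below are assumptions, not code from the repo:

import torch.nn as nn

class Policy(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        # Widen the state vector, then reduce to a single action in (-1, 1).
        self.layer1 = nn.Linear(input_channels, 2 * input_channels)
        self.tanh1 = nn.Tanh()
        self.layer2 = nn.Linear(2 * input_channels, 1)  # the line this commit completes
        self.tanh2 = nn.Tanh()

    def forward(self, state):
        # Assumed wiring; the body is not visible in the diff.
        x = self.tanh1(self.layer1(state))
        return self.tanh2(self.layer2(x))

The old line stopped at `self.layer2 = nn.`, a syntax error that would fail at import time, which plausibly explains the Space's Runtime error status; completing the nn.Linear(2 * input_channels, 1) call is the entire fix in this file.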
rl_agent/test_env.py CHANGED

@@ -4,6 +4,7 @@ from utils import myOptimizer
 
 import pandas as pd
 import numpy as np
+import torch
 
 if __name__ == "__main__":
 
@@ -22,9 +23,9 @@ if __name__ == "__main__":
     second_momentum = 0.0
     transaction_cost = 0.0001
     adaptation_rate = 0.01
-    state_size =
+    state_size = 8
 
-    agent = Policy(input_channels=state_size)
+    agent = Policy(input_channels=state_size).float()
     optimizer = myOptimizer(learning_rate, first_momentum, second_momentum, adaptation_rate, transaction_cost)
 
 
@@ -36,10 +37,12 @@ if __name__ == "__main__":
 
     env = Environment(train, history=history)
     observation = env.reset()
-    for _ in range(9, 12):
 
+    for _ in range(9, 12):
+        print(type(observation))
+        observation = torch.as_tensor(observation)
         action = agent(observation)
-        observation, reward, _ = env.step(action)
+        observation, reward, _ = env.step(action.data.numpy())
 
     print(env.profits)
 
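Taken together, these edits wire up the NumPy-to-torch round-trip: the environment emits NumPy observations, the policy consumes tensors, and env.step() receives NumPy again. Below is a self-contained sketch of that loop; ToyEnv is a hypothetical stand-in for the repo's Environment class, whose interface is only partly visible in this diff:

import numpy as np
import torch
import torch.nn as nn

class ToyEnv:
    # Hypothetical stand-in: reset/step deal in NumPy arrays, as Environment appears to.
    def reset(self):
        return np.zeros(8, dtype=np.float32)

    def step(self, action):
        observation = np.random.randn(8).astype(np.float32)
        reward = float(action.sum())
        return observation, reward, False

# A model with the same layer sizes as Policy(input_channels=8).float(): float32 parameters.
policy = nn.Sequential(
    nn.Linear(8, 16), nn.Tanh(),
    nn.Linear(16, 1), nn.Tanh(),
).float()

env = ToyEnv()
observation = env.reset()
for _ in range(3):
    obs_t = torch.as_tensor(observation)  # NumPy -> torch without copying
    action = policy(obs_t)
    # torch -> NumPy before handing the action back to the environment.
    observation, reward, done = env.step(action.detach().numpy())
print(reward)

Two caveats on the commit's exact lines: action.data.numpy() works, but .detach().numpy() is the current idiom for the same thing; and torch.as_tensor preserves dtype, so if Environment returns float64 arrays, the float32 model will still raise a dtype mismatch, which torch.as_tensor(observation).float() would avoid.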